@@ -45,14 +45,22 @@ class SymbolModality(modality.Modality):
4545 def name (self ):
4646 return "symbol_modality_%d_%d" % (self ._vocab_size , self ._body_input_depth )
4747
48- @property
49- def top_dimensionality (self ):
50- return self ._vocab_size
51-
5248 @property
5349 def top_is_pointwise (self ):
5450 return True
5551
52+ @property
53+ def weights_fn (self ):
54+ weights_fn = common_layers .weights_nonzero
55+
56+ hp = self ._model_hparams
57+ if hp and hp .prepend_mode != "none" :
58+ assert (hp .prepend_mode == "prepend_inputs_masked_attention" or
59+ hp .prepend_mode == "prepend_inputs_full_attention" )
60+ weights_fn = common_layers .weights_prepend_inputs_to_targets
61+
62+ return weights_fn
63+
5664 def _get_weights (self , hidden_dim = None ):
5765 """Create or get concatenated embedding or softmax variable.
5866
@@ -151,7 +159,7 @@ def top(self, body_output, _):
151159class CTCSymbolModality (SymbolModality ):
152160 """SymbolModality that uses CTC loss."""
153161
154- def loss (self , logits , targets , weights_fn = common_layers . weights_nonzero ):
162+ def loss (self , logits , targets ):
155163 """Compute the CTC loss."""
156164 with tf .name_scope ("ctc_loss" , [logits , targets ]):
157165 # For CTC we assume targets are 1d, [batch, length, 1, 1] here.
@@ -172,21 +180,14 @@ def loss(self, logits, targets, weights_fn=common_layers.weights_nonzero):
172180 time_major = False ,
173181 preprocess_collapse_repeated = False ,
174182 ctc_merge_repeated = False )
175- weights = weights_fn (targets )
183+ weights = self . targets_weights_fn (targets )
176184 return tf .reduce_sum (xent ), tf .reduce_sum (weights )
177185
178186
179187@registry .register_image_modality ("default" )
180188class ImageModality (modality .Modality ):
181189 """Modality for images."""
182-
183- def __init__ (self , model_hparams , vocab_size ):
184- super (ImageModality , self ).__init__ (model_hparams , vocab_size )
185- self ._channels = 3
186-
187- @property
188- def top_dimensionality (self ):
189- return 256
190+ NUM_CHANNELS = 3
190191
191192 def bottom (self , inputs ):
192193 with tf .variable_scope (self .name ):
@@ -217,7 +218,7 @@ def top(self, body_output, _):
217218 common_layers .shape_dim (body_output , i ) for i in range (3 )
218219 ]
219220 dim = body_output .get_shape ().as_list ()[- 1 ] // 3
220- reshape_shape .extend ([self ._channels , dim ])
221+ reshape_shape .extend ([self .NUM_CHANNELS , dim ])
221222
222223 out = tf .reshape (body_output , reshape_shape )
223224 res = tf .layers .dense (out , self .top_dimensionality )
@@ -226,21 +227,11 @@ def top(self, body_output, _):
226227 tf .summary .image ("result" , res_argmax , max_outputs = 1 )
227228 return res
228229
229- def loss (self , top_out , targets , weights_fn = common_layers .weights_all ):
230- # Call the default implementation, but weight 1.0 on 0s by default.
231- # (Since we're processing images and so have no padding and some pixel 0s.)
232- return super (ImageModality , self ).loss (
233- top_out , targets , weights_fn = weights_fn )
234-
235230
236231@registry .register_image_modality ("image_identity_compress" )
237232class ImageIdentityCompressModality (modality .Modality ):
238233 """Modality for images used in generation."""
239234
240- @property
241- def top_dimensionality (self ):
242- return 256
243-
244235 def bottom_compress (self , inputs , name = "bottom" ):
245236 """Transform input from data space to model space.
246237
@@ -296,12 +287,6 @@ def top(self, body_output, _):
296287 channels , self .top_dimensionality ])
297288 return x
298289
299- def loss (self , top_out , targets , weights_fn = common_layers .weights_all ):
300- # Call the default implementation, but weight 1.0 on 0s by default.
301- # (Since we're processing images and so have no padding and some pixel 0s.)
302- return super (ImageIdentityCompressModality , self ).loss (
303- top_out , targets , weights_fn = weights_fn )
304-
305290
306291@registry .register_audio_modality ("default" )
307292class AudioModality (modality .Modality ):
@@ -399,10 +384,6 @@ def name(self):
399384 return "class_label_modality_%d_%d" % (self ._vocab_size ,
400385 self ._body_input_depth )
401386
402- @property
403- def top_dimensionality (self ):
404- return self ._vocab_size
405-
406387 def bottom (self , x ):
407388 with tf .variable_scope (self .name ):
408389 return common_layers .embedding (
@@ -434,12 +415,6 @@ def top(self, body_output, _):
434415 res = tf .layers .dense (x , self ._vocab_size )
435416 return tf .expand_dims (res , 3 )
436417
437- def loss (self , top_out , targets , weights_fn = common_layers .weights_all ):
438- # Call the default implementation, but weight 1.0 on 0s by default.
439- # (Since we're processing images and so have no padding and some pixel 0s.)
440- return super (ClassLabelModality , self ).loss (
441- top_out , targets , weights_fn = weights_fn )
442-
443418
444419@registry .register_generic_modality ("default" )
445420@registry .register_audio_modality ("identity" )
@@ -450,10 +425,6 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
450425class IdentityModality (modality .Modality ):
451426 """Does nothing."""
452427
453- @property
454- def targets_dimensionality (self ):
455- return self ._vocab_size
456-
457428 def bottom (self , x ):
458429 return tf .to_float (x )
459430
@@ -476,7 +447,7 @@ def top(self, body_output, _):
476447 with tf .variable_scope ("real" ):
477448 return tf .layers .dense (body_output , self ._vocab_size )
478449
479- def loss (self , top_out , targets , weights_fn = common_layers . weights_all ):
450+ def loss (self , top_out , targets ):
480451 raise NotImplementedError ()
481452
482453
@@ -485,70 +456,35 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
485456class RealL2LossModality (RealModality ):
486457 """Modality for real (i.e. float) vectors with L2 (Gaussian) loss."""
487458
488- def loss (self , top_out , targets , weights_fn = common_layers . weights_all ):
459+ def loss (self , top_out , targets ):
489460 predictions = top_out
490461 with tf .name_scope ("l2" ):
491- weights = weights_fn (targets )
462+ weights = self . targets_weights_fn (targets )
492463 l2 = tf .pow (predictions - targets , 2 )
493464 return tf .reduce_sum (l2 * weights ), tf .reduce_sum (weights )
494465
495466
496467@registry .register_real_modality ("log_poisson_loss" )
497- class RealLogPoissonLossModality (RealL2LossModality ):
498- """Modality for real (i.e. float) vectors with log Poisson regression loss.
499- """
500-
501- def bottom (self , x ):
502- return x
468+ class RealLogPoissonLossModality (RealModality ):
469+ """Modality for real (i.e. float) vectors with log Poisson regression loss."""
503470
504- def loss (self , top_out , targets , weights_fn = common_layers . weights_all ):
471+ def loss (self , top_out , targets ):
505472 predictions = top_out
506473 with tf .name_scope ("log_possion" ):
507- weights = weights_fn (targets )
474+ weights = self . targets_weights_fn (targets )
508475
509476 lp_loss = tf .nn .log_poisson_loss (targets , predictions )
510477 return tf .reduce_sum (lp_loss * weights ), tf .reduce_sum (weights )
511478
512479
513- @registry .register_image_modality ("identity_no_pad" )
514- class IdentityModalityNoPad (modality .Modality ):
515- """Does nothing except making sure that there is no padding in cross-ent."""
516-
517- @property
518- def top_dimensionality (self ):
519- return 256
520-
521- @property
522- def targets_dimensionality (self ):
523- return self ._vocab_size
524-
525- def bottom (self , x ):
526- return tf .to_float (x )
527-
528- def top (self , body_output , _ ):
529- return body_output
530-
531- def loss (self , top_out , targets , weights_fn = common_layers .weights_all ):
532- # Call the default implementation, but weight 1.0 on 0s by default.
533- # (Since we're processing images and so have no padding and some pixel 0s.)
534- return super (IdentityModalityNoPad , self ).loss (
535- top_out , targets , weights_fn = weights_fn )
536-
537-
538- @registry .register_image_modality ("no_loss" )
539- class NoLossModality (modality .Modality ):
540- """Does nothing to the input and returns no loss."""
541-
542- @property
543- def targets_dimensionality (self ):
544- return self ._vocab_size
545-
546- def bottom (self , x ):
547- return tf .to_float (x )
548-
549- def top (self , body_output , _ ):
550- return body_output
480+ @registry .register_generic_modality ("zero_loss" )
481+ @registry .register_audio_modality ("zero_loss" )
482+ @registry .register_image_modality ("zero_loss" )
483+ @registry .register_symbol_modality ("zero_loss" )
484+ @registry .register_class_label_modality ("zero_loss" )
485+ @registry .register_real_modality ("zero_loss" )
486+ class IdentityZeroLossModality (IdentityModality ):
487+ """Identity with 0 loss."""
551488
552- def loss_sharded (self , sharded_top_out , sharded_targets , data_parallelism ):
553- """Return nothing."""
554- return tf .constant (0.0 , tf .float32 )
489+ def loss (self , top_out , targets ):
490+ return tf .constant (0. , tf .float32 ), tf .constant (0. , tf .float32 )
0 commit comments