Skip to content

Commit 50f5515

Browse files
author
Ryan Sepassi
committed
Add Modality.targets_weights_fn
PiperOrigin-RevId: 175722118
1 parent 6cf47f9 commit 50f5515

6 files changed

Lines changed: 79 additions & 130 deletions

File tree

tensor2tensor/data_generators/image.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ def preprocess_example(self, example, unused_mode, unused_hparams):
112112

113113
def hparams(self, defaults, unused_model_hparams):
114114
p = defaults
115-
p.input_modality = {"inputs": ("image:identity_no_pad", None)}
116-
p.target_modality = ("image:identity_no_pad", None)
115+
p.input_modality = {"inputs": ("image:identity", 256)}
116+
p.target_modality = ("image:identity", 256)
117117
p.batch_size_multiplier = 256
118118
p.max_expected_batch_size_per_shard = 4
119119
p.input_space_id = 1
@@ -236,7 +236,7 @@ def feature_encoders(self, data_dir):
236236

237237
def hparams(self, defaults, unused_model_hparams):
238238
p = defaults
239-
p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
239+
p.input_modality = {"inputs": (registry.Modalities.IMAGE, 256)}
240240
vocab_size = self._encoders["targets"].vocab_size
241241
p.target_modality = (registry.Modalities.SYMBOL, vocab_size)
242242
p.batch_size_multiplier = 256
@@ -286,7 +286,7 @@ def generator(self, data_dir, tmp_dir, is_training):
286286

287287
def hparams(self, defaults, unused_model_hparams):
288288
p = defaults
289-
p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
289+
p.input_modality = {"inputs": (registry.Modalities.IMAGE, 256)}
290290
p.target_modality = (registry.Modalities.CLASS_LABEL,
291291
self.num_classes)
292292
p.batch_size_multiplier = 4 if self.is_small else 256
@@ -432,8 +432,8 @@ def preprocess_example(self, example, unused_mode, unused_hparams):
432432

433433
def hparams(self, defaults, unused_model_hparams):
434434
p = defaults
435-
p.input_modality = {"inputs": ("image:identity_no_pad", None)}
436-
p.target_modality = ("image:identity_no_pad", None)
435+
p.input_modality = {"inputs": ("image:identity", 256)}
436+
p.target_modality = ("image:identity", 256)
437437
p.batch_size_multiplier = 256
438438
p.max_expected_batch_size_per_shard = 4
439439
p.input_space_id = 1
@@ -718,8 +718,8 @@ def preprocess_example(self, example, unused_mode, unused_hparams):
718718

719719
def hparams(self, defaults, unused_model_hparams):
720720
p = defaults
721-
p.input_modality = {"inputs": ("image:identity_no_pad", None)}
722-
p.target_modality = ("image:identity_no_pad", None)
721+
p.input_modality = {"inputs": ("image:identity", 256)}
722+
p.target_modality = ("image:identity", 256)
723723
p.batch_size_multiplier = 256
724724
p.max_expected_batch_size_per_shard = 4
725725
p.input_space_id = 1
@@ -863,7 +863,7 @@ def feature_encoders(self, data_dir):
863863

864864
def hparams(self, defaults, unused_model_hparams):
865865
p = defaults
866-
p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
866+
p.input_modality = {"inputs": (registry.Modalities.IMAGE, 256)}
867867
encoder = self._encoders["targets"]
868868
p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size)
869869
p.batch_size_multiplier = 256

tensor2tensor/layers/modalities.py

Lines changed: 33 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,22 @@ class SymbolModality(modality.Modality):
4545
def name(self):
4646
return "symbol_modality_%d_%d" % (self._vocab_size, self._body_input_depth)
4747

48-
@property
49-
def top_dimensionality(self):
50-
return self._vocab_size
51-
5248
@property
5349
def top_is_pointwise(self):
5450
return True
5551

52+
@property
53+
def weights_fn(self):
54+
weights_fn = common_layers.weights_nonzero
55+
56+
hp = self._model_hparams
57+
if hp and hp.prepend_mode != "none":
58+
assert (hp.prepend_mode == "prepend_inputs_masked_attention" or
59+
hp.prepend_mode == "prepend_inputs_full_attention")
60+
weights_fn = common_layers.weights_prepend_inputs_to_targets
61+
62+
return weights_fn
63+
5664
def _get_weights(self, hidden_dim=None):
5765
"""Create or get concatenated embedding or softmax variable.
5866
@@ -151,7 +159,7 @@ def top(self, body_output, _):
151159
class CTCSymbolModality(SymbolModality):
152160
"""SymbolModality that uses CTC loss."""
153161

154-
def loss(self, logits, targets, weights_fn=common_layers.weights_nonzero):
162+
def loss(self, logits, targets):
155163
"""Compute the CTC loss."""
156164
with tf.name_scope("ctc_loss", [logits, targets]):
157165
# For CTC we assume targets are 1d, [batch, length, 1, 1] here.
@@ -172,21 +180,14 @@ def loss(self, logits, targets, weights_fn=common_layers.weights_nonzero):
172180
time_major=False,
173181
preprocess_collapse_repeated=False,
174182
ctc_merge_repeated=False)
175-
weights = weights_fn(targets)
183+
weights = self.targets_weights_fn(targets)
176184
return tf.reduce_sum(xent), tf.reduce_sum(weights)
177185

178186

179187
@registry.register_image_modality("default")
180188
class ImageModality(modality.Modality):
181189
"""Modality for images."""
182-
183-
def __init__(self, model_hparams, vocab_size):
184-
super(ImageModality, self).__init__(model_hparams, vocab_size)
185-
self._channels = 3
186-
187-
@property
188-
def top_dimensionality(self):
189-
return 256
190+
NUM_CHANNELS = 3
190191

191192
def bottom(self, inputs):
192193
with tf.variable_scope(self.name):
@@ -217,7 +218,7 @@ def top(self, body_output, _):
217218
common_layers.shape_dim(body_output, i) for i in range(3)
218219
]
219220
dim = body_output.get_shape().as_list()[-1] // 3
220-
reshape_shape.extend([self._channels, dim])
221+
reshape_shape.extend([self.NUM_CHANNELS, dim])
221222

222223
out = tf.reshape(body_output, reshape_shape)
223224
res = tf.layers.dense(out, self.top_dimensionality)
@@ -226,21 +227,11 @@ def top(self, body_output, _):
226227
tf.summary.image("result", res_argmax, max_outputs=1)
227228
return res
228229

229-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
230-
# Call the default implementation, but weight 1.0 on 0s by default.
231-
# (Since we're processing images and so have no padding and some pixel 0s.)
232-
return super(ImageModality, self).loss(
233-
top_out, targets, weights_fn=weights_fn)
234-
235230

236231
@registry.register_image_modality("image_identity_compress")
237232
class ImageIdentityCompressModality(modality.Modality):
238233
"""Modality for images used in generation."""
239234

240-
@property
241-
def top_dimensionality(self):
242-
return 256
243-
244235
def bottom_compress(self, inputs, name="bottom"):
245236
"""Transform input from data space to model space.
246237
@@ -296,12 +287,6 @@ def top(self, body_output, _):
296287
channels, self.top_dimensionality])
297288
return x
298289

299-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
300-
# Call the default implementation, but weight 1.0 on 0s by default.
301-
# (Since we're processing images and so have no padding and some pixel 0s.)
302-
return super(ImageIdentityCompressModality, self).loss(
303-
top_out, targets, weights_fn=weights_fn)
304-
305290

306291
@registry.register_audio_modality("default")
307292
class AudioModality(modality.Modality):
@@ -399,10 +384,6 @@ def name(self):
399384
return "class_label_modality_%d_%d" % (self._vocab_size,
400385
self._body_input_depth)
401386

402-
@property
403-
def top_dimensionality(self):
404-
return self._vocab_size
405-
406387
def bottom(self, x):
407388
with tf.variable_scope(self.name):
408389
return common_layers.embedding(
@@ -434,12 +415,6 @@ def top(self, body_output, _):
434415
res = tf.layers.dense(x, self._vocab_size)
435416
return tf.expand_dims(res, 3)
436417

437-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
438-
# Call the default implementation, but weight 1.0 on 0s by default.
439-
# (Since we're processing images and so have no padding and some pixel 0s.)
440-
return super(ClassLabelModality, self).loss(
441-
top_out, targets, weights_fn=weights_fn)
442-
443418

444419
@registry.register_generic_modality("default")
445420
@registry.register_audio_modality("identity")
@@ -450,10 +425,6 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
450425
class IdentityModality(modality.Modality):
451426
"""Does nothing."""
452427

453-
@property
454-
def targets_dimensionality(self):
455-
return self._vocab_size
456-
457428
def bottom(self, x):
458429
return tf.to_float(x)
459430

@@ -476,7 +447,7 @@ def top(self, body_output, _):
476447
with tf.variable_scope("real"):
477448
return tf.layers.dense(body_output, self._vocab_size)
478449

479-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
450+
def loss(self, top_out, targets):
480451
raise NotImplementedError()
481452

482453

@@ -485,70 +456,35 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
485456
class RealL2LossModality(RealModality):
486457
"""Modality for real (i.e. float) vectors with L2 (Gaussian) loss."""
487458

488-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
459+
def loss(self, top_out, targets):
489460
predictions = top_out
490461
with tf.name_scope("l2"):
491-
weights = weights_fn(targets)
462+
weights = self.targets_weights_fn(targets)
492463
l2 = tf.pow(predictions - targets, 2)
493464
return tf.reduce_sum(l2 * weights), tf.reduce_sum(weights)
494465

495466

496467
@registry.register_real_modality("log_poisson_loss")
497-
class RealLogPoissonLossModality(RealL2LossModality):
498-
"""Modality for real (i.e. float) vectors with log Poisson regression loss.
499-
"""
500-
501-
def bottom(self, x):
502-
return x
468+
class RealLogPoissonLossModality(RealModality):
469+
"""Modality for real (i.e. float) vectors with log Poisson regression loss."""
503470

504-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
471+
def loss(self, top_out, targets):
505472
predictions = top_out
506473
with tf.name_scope("log_possion"):
507-
weights = weights_fn(targets)
474+
weights = self.targets_weights_fn(targets)
508475

509476
lp_loss = tf.nn.log_poisson_loss(targets, predictions)
510477
return tf.reduce_sum(lp_loss * weights), tf.reduce_sum(weights)
511478

512479

513-
@registry.register_image_modality("identity_no_pad")
514-
class IdentityModalityNoPad(modality.Modality):
515-
"""Does nothing except making sure that there is no padding in cross-ent."""
516-
517-
@property
518-
def top_dimensionality(self):
519-
return 256
520-
521-
@property
522-
def targets_dimensionality(self):
523-
return self._vocab_size
524-
525-
def bottom(self, x):
526-
return tf.to_float(x)
527-
528-
def top(self, body_output, _):
529-
return body_output
530-
531-
def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
532-
# Call the default implementation, but weight 1.0 on 0s by default.
533-
# (Since we're processing images and so have no padding and some pixel 0s.)
534-
return super(IdentityModalityNoPad, self).loss(
535-
top_out, targets, weights_fn=weights_fn)
536-
537-
538-
@registry.register_image_modality("no_loss")
539-
class NoLossModality(modality.Modality):
540-
"""Does nothing to the input and returns no loss."""
541-
542-
@property
543-
def targets_dimensionality(self):
544-
return self._vocab_size
545-
546-
def bottom(self, x):
547-
return tf.to_float(x)
548-
549-
def top(self, body_output, _):
550-
return body_output
480+
@registry.register_generic_modality("zero_loss")
481+
@registry.register_audio_modality("zero_loss")
482+
@registry.register_image_modality("zero_loss")
483+
@registry.register_symbol_modality("zero_loss")
484+
@registry.register_class_label_modality("zero_loss")
485+
@registry.register_real_modality("zero_loss")
486+
class IdentityZeroLossModality(IdentityModality):
487+
"""Identity with 0 loss."""
551488

552-
def loss_sharded(self, sharded_top_out, sharded_targets, data_parallelism):
553-
"""Return nothing."""
554-
return tf.constant(0.0, tf.float32)
489+
def loss(self, top_out, targets):
490+
return tf.constant(0., tf.float32), tf.constant(0., tf.float32)

tensor2tensor/models/vanilla_gan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ def vanilla_gan():
146146

147147
hparams = common_hparams.basic_params1()
148148

149-
hparams.input_modalities = "image:no_loss"
150-
hparams.target_modality = "image:no_loss"
149+
hparams.input_modalities = "inputs:image:zero_loss"
150+
hparams.target_modality = "image:zero_loss"
151151

152152
hparams.batch_size = 2048 # 3136
153153
hparams.label_smoothing = 0.0

tensor2tensor/tpu/tpu_trainer_lib.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import six
2727

28-
from tensor2tensor.layers import common_layers
2928
from tensor2tensor.utils import data_reader
3029
from tensor2tensor.utils import metrics
3130
from tensor2tensor.utils import optimize
@@ -192,7 +191,7 @@ def model_fn(features, labels, mode, params, config):
192191
problem = hp.problem_instances[0]
193192

194193
if use_tpu:
195-
eval_metrics_fn = create_eval_metrics_fn(problem)
194+
eval_metrics_fn = create_eval_metrics_fn(problem, hparams)
196195
_remove_summaries()
197196
return tf.contrib.tpu.TPUEstimatorSpec(
198197
mode,
@@ -245,14 +244,18 @@ def model_fn(features, labels, mode, params, config):
245244
])
246245

247246

248-
def create_eval_metrics_fn(problem):
247+
def create_eval_metrics_fn(problem, hparams):
249248
"""Create the metrics_fn that TPUEstimatorSpec expects."""
250249

250+
tm = problem.get_hparams().target_modality
251+
if isinstance(tm, tuple):
252+
tm = registry.create_modality(tm, hparams)
253+
weights_fn = tm.weights_fn
254+
251255
def make_metric_fn(metric_fn):
252256

253257
def wrapped_metric_fn(logits, labels):
254-
num, den = metric_fn(
255-
logits, labels, weights_fn=common_layers.weights_nonzero)
258+
num, den = metric_fn(logits, labels, weights_fn=weights_fn)
256259
return tf.metrics.mean(num, den)
257260

258261
return wrapped_metric_fn

tensor2tensor/utils/metrics.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from tensor2tensor.layers import common_layers
2626
from tensor2tensor.utils import bleu_hook
27+
from tensor2tensor.utils import registry
2728
from tensor2tensor.utils import rouge
2829

2930
import tensorflow as tf
@@ -284,7 +285,7 @@ def problem_metric_fn(predictions, features):
284285
# "features".
285286
kwargs = {}
286287
args, _, keywords, _ = inspect.getargspec(metric_fn)
287-
if "features" in args or keywords:
288+
if ("features" in args) or keywords:
288289
kwargs["features"] = features
289290

290291
def wrapped_metric_fn():
@@ -308,28 +309,21 @@ def wrapped_metric_fn():
308309
metrics,
309310
METRICS_FNS.keys()))
310311

311-
class_output = "image" in problem_name and "coco" not in problem_name
312-
real_output = "gene_expression" in problem_name
313-
if model_hparams.prepend_mode != "none":
314-
assert (model_hparams.prepend_mode == "prepend_inputs_masked_attention" or
315-
model_hparams.prepend_mode == "prepend_inputs_full_attention")
316-
assert not class_output
317-
weights_fn = common_layers.weights_prepend_inputs_to_targets
318-
elif class_output or real_output:
319-
weights_fn = common_layers.weights_all
320-
else:
321-
weights_fn = common_layers.weights_nonzero
322-
323312
def image_wrapped_metric_fn(predictions,
324313
labels,
325314
weights_fn=common_layers.weights_nonzero):
326315
_, _ = labels, weights_fn
327316
return metric_fn(predictions, model_hparams)
328317

318+
tm = problem_instance.get_hparams().target_modality
319+
if isinstance(tm, tuple):
320+
tm = registry.create_modality(tm, model_hparams)
321+
weights_fn = tm.weights_fn
322+
329323
for metric in metrics:
330324
metric_fn = METRICS_FNS[metric]
331325
metric_name = "metrics-%s/%s" % (problem_name, metric)
332-
if "image" in metric:
326+
if metric == Metrics.IMAGE_SUMMARY:
333327
eval_metrics[metric_name] = image_wrapped_metric_fn
334328
else:
335329
problem_metric_fn = make_problem_specific_metric_fn(

0 commit comments

Comments
 (0)