forked from tensorflow/tensor2tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodality.py
More file actions
164 lines (134 loc) · 5.82 KB
/
modality.py
File metadata and controls
164 lines (134 loc) · 5.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modality base class - defines the bottom and top of the model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
# Dependency imports
from tensor2tensor.layers import common_layers
import tensorflow as tf
class Modality(object):
  """Abstract Modality class for data transformations.

  An abstract class representing modalities for transforming data to a space
  interpretable by T2T models. It has 4 functions:
  * bottom: called on inputs entering the model.
  * targets_bottom: called on targets entering the model (e.g., the decoder).
  * top: called on model outputs to generate predictions (e.g., logits).
  * loss: called on predictions (outputs of top) and targets.

  For example, think about a modality for images:
  * `bottom` represents the part of the model applied to an incoming image,
    e.g., an entry flow of a convolutional network.
  * `top` represents the top part of a model that is generating images, e.g., a
    PixelCNN network.
  * `targets_bottom` represents the auto-regressive part of the network.  It is
    applied to the already-generated part of an image, which is given to the
    decoder to generate the next part. In some cases, e.g., for text, it is the
    same as the `bottom` function, and that is the default we use. But, e.g.,
    for images, a different function might be needed to regress properly.
  * `loss` would compare the generated image to the target image and score it.

  All the functions have simple and sharded versions. A sub-class only needs to
  implement the simple version, the default sharding will be used then.
  """

  def __init__(self, model_hparams, vocab_size=None):
    """Store the hyperparameters and (optional) vocabulary size.

    Args:
      model_hparams: hyperparameter object; this class reads its `hidden_size`
        and `label_smoothing` attributes.
      vocab_size: optional vocabulary size, stored as-is for sub-classes.
    """
    self._model_hparams = model_hparams
    self._vocab_size = vocab_size

  @property
  def name(self):
    """String, the snake_case form of the concrete class name.

    For example, a sub-class called `SymbolModality` is named
    `symbol_modality`.
    """
    camelcase_name = type(self).__name__  # DeCamelCase for TF readability.
    snake_name = re.sub("([A-Z]+)", r"_\1", camelcase_name).lower()
    # The substitution puts "_" in front of every run of capitals, so a class
    # name that starts with a capital gains a spurious leading underscore.
    # Strip that underscore only when it is present: slicing off the first
    # character unconditionally would truncate names that start lowercase.
    if snake_name.startswith("_"):
      return snake_name[1:]
    return snake_name

  @property
  def top_dimensionality(self):
    """Integer, the last dimension of the predictions (vocab size)."""
    raise NotImplementedError("Abstract Method")

  @property
  def _body_input_depth(self):
    """The `hidden_size` hyperparameter, used as the depth fed to the body."""
    return self._model_hparams.hidden_size

  def bottom(self, x):
    """Transform one shard of input.

    Args:
      x: An int32 Tensor with shape [batch, p0, p1, input_channels]

    Returns:
      A float32 Tensor with shape [batch, p0, p1, body_input_depth]
    """
    raise NotImplementedError("Abstract Method")

  def bottom_sharded(self, xs, data_parallelism):
    """Transform the inputs.

    Args:
      xs: A list of num_datashards Tensors (one per shard)
        each with shape [batch, p0, p1, depth]
      data_parallelism: an expert_utils.Parallelism object

    Returns:
      sharded_body_input: A list of num_datashards Tensors, each with shape
        [batch, p0, p1, body_input_depth].
    """
    return data_parallelism(self.bottom, xs)

  def targets_bottom(self, x):
    """Transform one shard of targets.

    By default this applies `bottom` inside a "targets_bottom" variable scope.

    Args:
      x: An int32 Tensor with shape [batch, p0, p1, target_channels]

    Returns:
      A float32 Tensor with shape [batch, p0, p1, body_input_depth]
    """
    with tf.variable_scope("targets_bottom"):
      return self.bottom(x)

  def targets_bottom_sharded(self, xs, data_parallelism):
    """Transform the targets.

    Args:
      xs: A list of num_datashards Tensors (one per shard)
        each with shape [batch, p0, p1, target_channels]
      data_parallelism: an expert_utils.Parallelism object

    Returns:
      sharded_body_input: A list of num_datashards Tensors, each with shape
        [batch, p0, p1, body_input_depth].
    """
    return data_parallelism(self.targets_bottom, xs)

  def top(self, body_output, targets):
    """Generate predictions/logits for one shard of output.

    Most classes will override this function.

    Args:
      body_output: A Tensor with shape [batch, p0, p1, body_output_depth]
      targets: A Tensor with shape [batch, p0, p1, targets_channels,
        top_dimensionality]

    Returns:
      A Tensor of class logits.
    """
    raise NotImplementedError("Abstract Method")

  def top_sharded(self, sharded_body_output, sharded_targets, data_parallelism):
    """Generate predictions/logits for all shards.

    Classes with cross-shard interaction will override this function.

    Args:
      sharded_body_output: A list of Tensors.
      sharded_targets: A list of Tensors.
      data_parallelism: an expert_utils.Parallelism object.

    Returns:
      sharded_logits: A list of Tensors.
    """
    return data_parallelism(self.top, sharded_body_output, sharded_targets)

  def loss(self, top_out, targets, weights_fn=common_layers.weights_nonzero):
    """Compute loss numerator and denominator for one shard of output.

    Args:
      top_out: the output of `top` for this shard, used as the logits.
      targets: the targets for this shard.
      weights_fn: weighting function forwarded to
        `common_layers.padded_cross_entropy`; defaults to
        `common_layers.weights_nonzero`.

    Returns:
      The result of `common_layers.padded_cross_entropy` called with the
      `label_smoothing` hyperparameter; `loss_sharded` unpacks it as a
      (numerator, denominator) pair.
    """
    logits = top_out
    return common_layers.padded_cross_entropy(
        logits,
        targets,
        self._model_hparams.label_smoothing,
        weights_fn=weights_fn)

  def loss_sharded(self, sharded_top_out, sharded_targets, data_parallelism):
    """Compute loss for all shards.

    Args:
      sharded_top_out: A list of `top` outputs, one per shard.
      sharded_targets: A list of target Tensors, one per shard.
      data_parallelism: an expert_utils.Parallelism object.

    Returns:
      The sum of the per-shard loss numerators divided by the sum of the
      per-shard loss denominators, with the divisor clamped to at least 1.0.
    """
    sharded_loss_num, sharded_loss_den = data_parallelism(
        self.loss, sharded_top_out, sharded_targets)
    # Clamp the divisor so a zero (or sub-1) total weight cannot blow up the
    # quotient.
    loss = tf.add_n(sharded_loss_num) / tf.maximum(1.0,
                                                   tf.add_n(sharded_loss_den))
    return loss