Skip to content

Commit 3950b40

Browse files
author
Ryan Sepassi
committed
Fix Problem.filepattern to include PREDICT
PiperOrigin-RevId: 170415717
1 parent fb858cb commit 3950b40

3 files changed

Lines changed: 29 additions & 34 deletions

File tree

tensor2tensor/data_generators/problem.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,12 +235,26 @@ def test_filepaths(self, data_dir, num_shards, shuffled):
235235
num_shards)
236236

237237
def filepattern(self, data_dir, mode):
238-
"""Get filepattern for data files for mode."""
238+
"""Get filepattern for data files for mode.
239+
240+
Matches mode to a suffix.
241+
* TRAIN: train
242+
* EVAL: dev
243+
* PREDICT: dev
244+
* test: test
245+
246+
Args:
247+
data_dir: str, data directory.
248+
mode: tf.estimator.ModeKeys or "test".
249+
250+
Returns:
251+
filepattern str
252+
"""
239253
path = os.path.join(data_dir, self.dataset_filename())
240254

241255
if mode == tf.estimator.ModeKeys.TRAIN:
242256
suffix = "train"
243-
elif mode == tf.estimator.ModeKeys.EVAL:
257+
elif mode in [tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT]:
244258
suffix = "dev"
245259
else:
246260
assert mode == "test"

tensor2tensor/utils/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def nth_model(n):
288288
diet_vars = [
289289
v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
290290
]
291-
_log_variable_sizes(diet_vars, "Diet Variables")
291+
_log_variable_sizes(diet_vars, "Diet Varaibles")
292292

293293
# Optimize
294294
total_loss = tf.identity(total_loss, name="total_loss")

tensor2tensor/visualization/TransformerVisualization.ipynb

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515
{
1616
"cell_type": "code",
1717
"execution_count": 1,
18-
"metadata": {
19-
"collapsed": true
20-
},
18+
"metadata": {},
2119
"outputs": [],
2220
"source": [
2321
"from __future__ import absolute_import\n",
@@ -36,9 +34,7 @@
3634
{
3735
"cell_type": "code",
3836
"execution_count": 2,
39-
"metadata": {
40-
"collapsed": false
41-
},
37+
"metadata": {},
4238
"outputs": [
4339
{
4440
"data": {
@@ -76,9 +72,7 @@
7672
{
7773
"cell_type": "code",
7874
"execution_count": 3,
79-
"metadata": {
80-
"collapsed": false
81-
},
75+
"metadata": {},
8276
"outputs": [
8377
{
8478
"name": "stdout",
@@ -111,7 +105,6 @@
111105
"cell_type": "code",
112106
"execution_count": 4,
113107
"metadata": {
114-
"collapsed": false,
115108
"scrolled": true
116109
},
117110
"outputs": [
@@ -183,9 +176,7 @@
183176
{
184177
"cell_type": "code",
185178
"execution_count": 6,
186-
"metadata": {
187-
"collapsed": false
188-
},
179+
"metadata": {},
189180
"outputs": [
190181
{
191182
"name": "stdout",
@@ -200,15 +191,13 @@
200191
],
201192
"source": [
202193
"spec = utils.model_builder.model_fn(MODEL, features, tf.estimator.ModeKeys.EVAL, hparams, problem_names=[PROBLEM])\n",
203-
"predictions_dict = spec.predictions"
194+
"predictions_dict = spec.predictions",
204195
]
205196
},
206197
{
207198
"cell_type": "code",
208199
"execution_count": 7,
209-
"metadata": {
210-
"collapsed": false
211-
},
200+
"metadata": {},
212201
"outputs": [
213202
{
214203
"name": "stdout",
@@ -225,7 +214,7 @@
225214
"source": [
226215
"with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n",
227216
" spec = utils.model_builder.model_fn(MODEL, features, tf.estimator.ModeKeys.PREDICT, hparams, problem_names=[PROBLEM])\n",
228-
" beam_out = spec.predictions['outputs']"
217+
" beam_out = spec.predictions['outputs']",
229218
]
230219
},
231220
{
@@ -238,9 +227,7 @@
238227
{
239228
"cell_type": "code",
240229
"execution_count": 8,
241-
"metadata": {
242-
"collapsed": false
243-
},
230+
"metadata": {},
244231
"outputs": [
245232
{
246233
"name": "stdout",
@@ -320,7 +307,6 @@
320307
"cell_type": "code",
321308
"execution_count": 10,
322309
"metadata": {
323-
"collapsed": false,
324310
"scrolled": false
325311
},
326312
"outputs": [
@@ -367,9 +353,7 @@
367353
{
368354
"cell_type": "code",
369355
"execution_count": 12,
370-
"metadata": {
371-
"collapsed": false
372-
},
356+
"metadata": {},
373357
"outputs": [
374358
{
375359
"name": "stdout",
@@ -408,9 +392,7 @@
408392
{
409393
"cell_type": "code",
410394
"execution_count": 14,
411-
"metadata": {
412-
"collapsed": false
413-
},
395+
"metadata": {},
414396
"outputs": [
415397
{
416398
"data": {
@@ -458,7 +440,6 @@
458440
"cell_type": "code",
459441
"execution_count": null,
460442
"metadata": {
461-
"collapsed": true,
462443
"scrolled": true
463444
},
464445
"outputs": [],
@@ -486,9 +467,9 @@
486467
"name": "python",
487468
"nbconvert_exporter": "python",
488469
"pygments_lexer": "ipython2",
489-
"version": "2.7.13"
470+
"version": "2.7.12"
490471
}
491472
},
492473
"nbformat": 4,
493474
"nbformat_minor": 2
494-
}
475+
}

0 commit comments

Comments
 (0)