Skip to content

Commit 04b4f47

Browse files
author
Ryan Sepassi
committed
Update documentation for adding new Problems
PiperOrigin-RevId: 162242293
1 parent 963730e commit 04b4f47

13 files changed

Lines changed: 111 additions & 184 deletions

File tree

README.md

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ python -c "from tensor2tensor.models.transformer import Transformer"
153153
specification.
154154
* Support for multi-GPU machines and synchronous (1 master, many workers) and
155155
asynchronous (independent workers synchronizing through a parameter server)
156-
distributed training.
156+
[distributed training](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/docs/distributed_training.md).
157157
* Easily swap amongst datasets and models by command-line flag with the data
158158
generation script `t2t-datagen` and the training script `t2t-trainer`.
159159

@@ -173,8 +173,10 @@ and many common sequence datasets are already available for generation and use.
173173

174174
**Problems** define training-time hyperparameters for the dataset and task,
175175
mainly by setting input and output **modalities** (e.g. symbol, image, audio,
176-
label) and vocabularies, if applicable. All problems are defined in
177-
[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py).
176+
label) and vocabularies, if applicable. All problems are defined either in
177+
[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py)
178+
or are registered with `@registry.register_problem` (run `t2t-datagen` to see
179+
the list of all available problems).
178180
**Modalities**, defined in
179181
[`modality.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/modality.py),
180182
abstract away the input and output data types so that **models** may deal with
@@ -222,7 +224,7 @@ enables easily adding new ones and easily swapping amongst them by command-line
222224
flag. You can add your own components without editing the T2T codebase by
223225
specifying the `--t2t_usr_dir` flag in `t2t-trainer`.
224226

225-
You can currently do so for models, hyperparameter sets, and modalities. Please
227+
You can do so for models, hyperparameter sets, modalities, and problems. Please
226228
do submit a pull request if your component might be useful to others.
227229

228230
Here's an example with a new hyperparameter set:
@@ -253,9 +255,18 @@ You'll see under the registered HParams your
253255
`transformer_my_very_own_hparams_set`, which you can directly use on the command
254256
line with the `--hparams_set` flag.
255257

258+
`t2t-datagen` also supports the `--t2t_usr_dir` flag for `Problem`
259+
registrations.
260+
256261
## Adding a dataset
257262

258-
See the [data generators
263+
To add a new dataset, subclass
264+
[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
265+
and register it with `@registry.register_problem`. See
266+
[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
267+
for an example.
268+
269+
Also see the [data generators
259270
README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md).
260271

261272
---
File renamed without changes.

tensor2tensor/bin/t2t-datagen

100755100644
Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ from tensor2tensor.data_generators import wiki
4848
from tensor2tensor.data_generators import wmt
4949
from tensor2tensor.data_generators import wsj_parsing
5050
from tensor2tensor.utils import registry
51-
from tensor2tensor.utils import usr_dir
5251

5352
import tensorflow as tf
5453

@@ -65,13 +64,6 @@ flags.DEFINE_integer("max_cases", 0,
6564
"Maximum number of cases to generate (unbounded if 0).")
6665
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
6766

68-
flags.DEFINE_string("t2t_usr_dir", "",
69-
"Path to a Python module that will be imported. The "
70-
"__init__.py file should include the necessary imports. "
71-
"The imported files should contain registrations, "
72-
"e.g. @registry.register_model calls, that will then be "
73-
"available to the t2t-datagen.")
74-
7567
# Mapping from problems that we can generate data for to their generators.
7668
# pylint: disable=g-long-lambda
7769
_SUPPORTED_PROBLEM_GENERATORS = {
@@ -281,7 +273,6 @@ def set_random_seed():
281273

282274
def main(_):
283275
tf.logging.set_verbosity(tf.logging.INFO)
284-
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
285276

286277
# Calculate the list of problems to generate.
287278
problems = sorted(
@@ -365,7 +356,7 @@ def generate_data_for_problem(problem):
365356

366357
def generate_data_for_registered_problem(problem_name):
367358
problem = registry.problem(problem_name)
368-
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.num_shards)
359+
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
369360

370361

371362
if __name__ == "__main__":

tensor2tensor/bin/t2t-trainer

100755100644
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import sys
3636
# Dependency imports
3737

3838
from tensor2tensor.utils import trainer_utils as utils
39-
from tensor2tensor.utils import usr_dir
39+
4040
import tensorflow as tf
4141

4242
flags = tf.flags
@@ -49,9 +49,25 @@ flags.DEFINE_string("t2t_usr_dir", "",
4949
"e.g. @registry.register_model calls, that will then be "
5050
"available to the t2t-trainer.")
5151

52+
53+
def import_usr_dir():
54+
"""Import module at FLAGS.t2t_usr_dir, if provided."""
55+
if not FLAGS.t2t_usr_dir:
56+
return
57+
dir_path = os.path.expanduser(FLAGS.t2t_usr_dir)
58+
if dir_path[-1] == "/":
59+
dir_path = dir_path[:-1]
60+
containing_dir, module_name = os.path.split(dir_path)
61+
tf.logging.info("Importing user module %s from path %s", module_name,
62+
containing_dir)
63+
sys.path.insert(0, containing_dir)
64+
importlib.import_module(module_name)
65+
sys.path.pop(0)
66+
67+
5268
def main(_):
5369
tf.logging.set_verbosity(tf.logging.INFO)
54-
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
70+
import_usr_dir()
5571
utils.log_registry()
5672
utils.validate_flags()
5773
utils.run(
Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
# Data generators for T2T models.
1+
# T2T Problems.
22

3-
This directory contains data generators for a number of problems. We use a
4-
naming scheme for the problems, they have names of the form
3+
This directory contains `Problem` specifications for a number of problems. We
4+
use a naming scheme for the problems; they have names of the form
55
`[task-family]_[task]_[specifics]`. Data for all currently supported problems
66
can be generated by calling the main generator binary (`t2t-datagen`). For
77
example:
@@ -20,53 +20,51 @@ All tasks produce TFRecord files of `tensorflow.Example` protocol buffers.
2020

2121
## Adding a new problem
2222

23-
1. Implement and register a Python generator for the dataset
24-
1. Add a problem specification to `problem_hparams.py` specifying input and
25-
output modalities
23+
To add a new problem, subclass
24+
[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
25+
and register it with `@registry.register_problem`. See
26+
[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
27+
for an example.
2628

27-
To add a new problem, you first need to create python generators for training
28-
and development data for the problem. The python generators should yield
29-
dictionaries with string keys and values being lists of {int, float, str}.
30-
Here is a very simple generator for a data-set where inputs are lists of 1s with
31-
length upto 100 and targets are lists of length 1 with an integer denoting the
32-
length of the input list.
29+
`Problem`s support data generation, training, and decoding.
30+
31+
Data generation is handled by `Problem.generate_data` which should produce 2
32+
datasets, training and dev, which should be named according to
33+
`Problem.training_filepaths` and `Problem.dev_filepaths`.
34+
`Problem.generate_data` should also produce any other files that may be required
35+
for training/decoding, e.g. a vocabulary file.
36+
37+
A particularly easy way to implement `Problem.generate_data` for your dataset is
38+
to create 2 Python generators, one for the training data and another for the
39+
dev data, and pass them to `generator_utils.generate_dataset_and_shuffle`. See
40+
[`WMTEnDeTokens8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
41+
for an example of usage.
42+
43+
The generators should yield dictionaries with string keys and values being lists
44+
of {int, float, str}. Here is a very simple generator for a data-set where
45+
inputs are lists of 2s with length up to 100 and targets are lists of length 1
46+
with an integer denoting the length of the input list.
3347

3448
```
3549
def length_generator(nbr_cases):
3650
for _ in xrange(nbr_cases):
3751
length = np.random.randint(100) + 1
38-
yield {"inputs": [1] * length, "targets": [length]}
52+
yield {"inputs": [2] * length, "targets": [length]}
3953
```
4054

41-
Note that our data reader uses 0 for padding, so it is a good idea to never
42-
generate 0s, except if all your examples have the same size (in which case
43-
they'll never be padded anyway) or if you're doing padding on your own (in which
44-
case please use 0s for padding). When adding the python generator function,
45-
please also add unit tests to check if the code runs.
55+
Note that our data reader uses 0 for padding and other parts of the code assume
56+
end-of-string (EOS) is 1, so it is a good idea to never generate 0s or 1s,
57+
except if all your examples have the same size (in which case they'll never be
58+
padded anyway) or if you're doing padding on your own (in which case please use
59+
0s for padding). When adding the python generator function, please also add unit
60+
tests to check if the code runs.
4661

4762
The generator can do arbitrary setup before beginning to yield examples - for
4863
example, downloading data, generating vocabulary files, etc.
4964

5065
Some examples:
5166

52-
* [Algorithmic generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py)
67+
* [Algorithmic problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py)
5368
and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic_test.py)
54-
* [WMT generators](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
69+
* [WMT problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py)
5570
and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt_test.py)
56-
57-
When your python generator is ready and tested, add it to the
58-
`_SUPPORTED_PROBLEM_GENERATORS` dictionary in the
59-
[data
60-
generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen).
61-
The keys are problem names, and the values are pairs of (training-set-generator
62-
function, dev-set-generator function). For the generator above, one could add
63-
the following lines:
64-
65-
```
66-
"algorithmic_length_upto100":
67-
(lambda: algorithmic.length_generator(10000),
68-
lambda: algorithmic.length_generator(1000)),
69-
```
70-
71-
Note the lambdas above: we don't want to call the generators too early.
72-

tensor2tensor/data_generators/algorithmic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ class AlgorithmicIdentityBinary40(problem.Problem):
3636
def num_symbols(self):
3737
return 2
3838

39-
def generate_data(self, data_dir, _, num_shards=100):
39+
def generate_data(self, data_dir, _):
4040
utils.generate_dataset_and_shuffle(
4141
identity_generator(self.num_symbols, 40, 100000),
42-
self.training_filepaths(data_dir, num_shards, shuffled=True),
42+
self.training_filepaths(data_dir, 100, shuffled=True),
4343
identity_generator(self.num_symbols, 400, 10000),
4444
self.dev_filepaths(data_dir, 1, shuffled=True),
4545
shuffle=False)

tensor2tensor/data_generators/generator_utils.py

100755100644
Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,6 @@ def gunzip_file(gz_path, new_path):
244244
"http://www.statmt.org/wmt13/training-parallel-un.tgz",
245245
["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"]
246246
],
247-
# Macedonian-English
248-
[
249-
"https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz", # pylint: disable=line-too-long
250-
["train.mk", "train.en"]
251-
],
252247
]
253248

254249

@@ -329,19 +324,18 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename,
329324
return vocab
330325

331326
# Use Tokenizer to count the word occurrences.
332-
token_counts = defaultdict(int)
333327
filepath = os.path.join(tmp_dir, source_filename)
334328
with tf.gfile.GFile(filepath, mode="r") as source_file:
335329
for line in source_file:
336330
line = line.strip()
337331
if line and "\t" in line:
338332
parts = line.split("\t", maxsplit=1)
339333
part = parts[index].strip()
340-
for tok in tokenizer.encode(text_encoder.native_to_unicode(part)):
341-
token_counts[tok] += 1
334+
_ = tokenizer.encode(text_encoder.native_to_unicode(part))
342335

343336
vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
344-
vocab_size, token_counts, 1, 1e3)
337+
vocab_size, tokenizer.token_counts, 1,
338+
min(1e3, vocab_size + text_encoder.NUM_RESERVED_TOKENS))
345339
vocab.store_to_file(vocab_filepath)
346340
return vocab
347341

tensor2tensor/data_generators/problem.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ class SpaceID(object):
6767
ICE_TOK = 18
6868
# Icelandic parse tokens
6969
ICE_PARSE_TOK = 19
70-
# Macedonian tokens
71-
MK_TOK = 20
7270

7371

7472
class Problem(object):
@@ -113,7 +111,7 @@ class Problem(object):
113111
# BEGIN SUBCLASS INTERFACE
114112
# ============================================================================
115113

116-
def generate_data(self, data_dir, tmp_dir, num_shards=100):
114+
def generate_data(self, data_dir, tmp_dir):
117115
raise NotImplementedError()
118116

119117
def hparams(self, defaults, model_hparams):

tensor2tensor/data_generators/text_encoder.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from __future__ import print_function
2525

2626
from collections import defaultdict
27-
import re
2827

2928
# Dependency imports
3029

@@ -226,7 +225,6 @@ class SubwordTextEncoder(TextEncoder):
226225

227226
def __init__(self, filename=None):
228227
"""Initialize and read from a file, if provided."""
229-
self._alphabet = set()
230228
if filename is not None:
231229
self._load_from_file(filename)
232230
super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)
@@ -505,12 +503,6 @@ def _escape_token(self, token):
505503
ret += u"\\%d;" % ord(c)
506504
return ret
507505

508-
# Regular expression for unescaping token strings
509-
# '\u' is converted to '_'
510-
# '\\' is converted to '\'
511-
# '\213;' is converted to unichr(213)
512-
_UNESCAPE_REGEX = re.compile(u'|'.join([r"\\u", r"\\\\", r"\\([0-9]+);"]))
513-
514506
def _unescape_token(self, escaped_token):
515507
"""Inverse of _escape_token().
516508
@@ -519,14 +511,32 @@ def _unescape_token(self, escaped_token):
519511
Returns:
520512
token: a unicode string
521513
"""
522-
def match(m):
523-
if m.group(1) is not None:
524-
# Convert '\213;' to unichr(213)
525-
try:
526-
return unichr(int(m.group(1)))
527-
except (ValueError, OverflowError) as _:
528-
return ""
529-
# Convert '\u' to '_' and '\\' to '\'
530-
return u"_" if m.group(0) == u"\\u" else u"\\"
531-
# Cut off the trailing underscore and apply the regex substitution
532-
return self._UNESCAPE_REGEX.sub(match, escaped_token[:-1])
514+
ret = u""
515+
escaped_token = escaped_token[:-1]
516+
pos = 0
517+
while pos < len(escaped_token):
518+
c = escaped_token[pos]
519+
if c == "\\":
520+
pos += 1
521+
if pos >= len(escaped_token):
522+
break
523+
c = escaped_token[pos]
524+
if c == u"u":
525+
ret += u"_"
526+
pos += 1
527+
elif c == "\\":
528+
ret += u"\\"
529+
pos += 1
530+
else:
531+
semicolon_pos = escaped_token.find(u";", pos)
532+
if semicolon_pos == -1:
533+
continue
534+
try:
535+
ret += unichr(int(escaped_token[pos:semicolon_pos]))
536+
pos = semicolon_pos + 1
537+
except (ValueError, OverflowError) as _:
538+
pass
539+
else:
540+
ret += c
541+
pos += 1
542+
return ret

tensor2tensor/data_generators/tokenizer_test.py

100755100644
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
# Copyright 2017 The Tensor2Tensor Authors.
32
#
43
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
1312
# See the License for the specific language governing permissions and
1413
# limitations under the License.
1514

15+
# coding=utf-8
1616
"""Tests for tensor2tensor.data_generators.tokenizer."""
1717

1818
from __future__ import absolute_import

0 commit comments

Comments
 (0)