diff --git a/README.md b/README.md index af9778725..0e97770ba 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)](https://travis-ci.org/tensorflow/tensor2tensor) [T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible library and binaries for supervised learning with TensorFlow and with support @@ -123,8 +124,7 @@ t2t-decoder \ --model=$MODEL \ --hparams_set=$HPARAMS \ --output_dir=$TRAIN_DIR \ - --decode_beam_size=$BEAM_SIZE \ - --decode_alpha=$ALPHA \ + --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \ --decode_from_file=$DECODE_FILE cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes diff --git a/docs/example_life.md b/docs/example_life.md index 2983f5077..f3b18a817 100644 --- a/docs/example_life.md +++ b/docs/example_life.md @@ -9,26 +9,189 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) -This document show how a training example passes through the T2T pipeline, -and how all its parts are connected to work together. +This doc explains how a training example flows through T2T, from data generation +to training, evaluation, and decoding. It points out the various hooks available +in the `Problem` and `T2TModel` classes and gives an overview of the T2T code +(key functions, files, hyperparameters, etc.). -## The Life of an Example +Some key files and their functions: -A training example passes the following stages in T2T: -* raw input (text from command line or file) -* encoded input after [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode` is usually a sparse tensor, e.g., a vector of `tf.int32`s -* batched input after [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188) are grouped by their length and made into batches. -* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`. -* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542) -* back to sparse output through [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`. -* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen. +* [`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py): + Constructs and runs all the main components of the system (the `Problem`, + the `HParams`, the `Estimator`, the `Experiment`, the `input_fn`s and + `model_fn`). +* [`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py): + `basic_params1` serves as the base for all model hyperparameters. Registered + model hparams functions always start with this default set of + hyperparameters. +* [`problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py): + Every dataset in T2T subclasses `Problem`. +* [`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py): + Every model in T2T subclasses `T2TModel`. -We go into these phases step by step below. +## Data Generation -## Feature Encoders +The `t2t-datagen` binary is the entrypoint for data generation. It simply looks +up the `Problem` specified by `--problem` and calls +`Problem.generate_data(data_dir, tmp_dir)`. -TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions. +All `Problem`s are expected to generate 2 sharded `TFRecords` files - 1 for +training and 1 for evaluation - with `tensorflow.Example` protocol buffers. The +expected names of the files are given by `Problem.{training, dev}_filepaths`. +Typically, the features in the `Example` will be `"inputs"` and `"targets"`; +however, some tasks have a different on-disk representation that is converted to +`"inputs"` and `"targets"` online in the input pipeline (e.g. image features are +typically stored with features `"image/encoded"` and `"image/format"` and the +decoding happens in the input pipeline). -## Modalities +For tasks that require a vocabulary, this is also the point at which the +vocabulary is generated and all examples are encoded. -TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets. +There are several utility functions in +[`generator_utils`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator_utils.py) +that are commonly used by `Problem`s to generate data. Several are highlighted +below: + +* `generate_dataset_and_shuffle`: given 2 generators, 1 for training and 1 for + eval, yielding dictionaries of `>`, will produce sharded and shuffled `TFRecords` files with + `tensorflow.Example` protos. +* `maybe_download`: downloads a file at a URL to the given directory and + filename (see `maybe_download_from_drive` if the URL points to Google + Drive). +* `get_or_generate_vocab_inner`: given a target vocabulary size and a + generator that yields lines or tokens from the dataset, will build a + `SubwordTextEncoder` along with a backing vocabulary file that can be used + to map input strings to lists of ids. + [`SubwordTextEncoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py) + uses word pieces and its encoding is fully invertible. + +## Data Input Pipeline + +Once the data is produced on disk, training, evaluation, and inference (if +decoding from the dataset) consume it by way of T2T input pipeline. This section +will give an overview of that pipeline with specific attention to the various +hooks in the `Problem` class and the model's `HParams` object (typically +registered in the model's file and specified by the `--hparams_set` flag). + +The entire input pipeline is implemented with the new `tf.data.Dataset` API +(previously `tf.contrib.data.Dataset`). + +The key function in the codebase for the input pipeline is +[`data_reader.input_pipeline`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/data_reader.py). +The full input function is built in +[`input_fn_builder.build_input_fn`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/input_fn_builder.py) +(which calls `data_reader.input_pipeline`). + +### Reading and decoding data + +`Problem.dataset_filename` specifies the prefix of the files on disk (they will +be suffixed with `-train` or `-dev` as well as their sharding). + +The features read from the files and their decoding is specified by +`Problem.example_reading_spec`, which returns 2 items: + +1. Dict mapping from on-disk feature name to on-disk types (`VarLenFeature` or + `FixedLenFeature`. +2. Dict mapping output feature name to decoder. This return value is optional + and is only needed for tasks whose features may require additional decoding + (e.g. images). You can find the available decoders in + `tf.contrib.slim.tfexample_decoder`. + +At this point in the input pipeline, the example is a `dict`. + +### Preprocessing + +The read `Example` now runs through `Problem.preprocess_example`, which by +default runs +[`problem.preprocess_example_common`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py), +which may truncate the inputs/targets or prepend to targets, governed by some +hyperparameters. + +### Batching + +Examples are bucketed by sequence length and then batched out of those buckets. +This significantly improves performance over a naive batching scheme for +variable length sequences because each example in a batch must be padded to +match the example with the maximum length in the batch. + +There are several hyperparameters that affect how examples are batched together: + +* `hp.batch_size`: this is the approximate total number of tokens in the batch + (i.e. for a sequence problem, long sequences will have smaller actual batch + size and short sequences will have a larger actual batch size in order to + generally have an equal number of tokens in the batch). +* `hp.max_length`: sequences with length longer than this will be dropped + during training (and also during eval if `hp.eval_drop_long_sequences` is + `True`). If not set, the maximum length of examples is set to + `hp.batch_size`. +* `hp.batch_size_multiplier`: multiplier for the maximum length +* `hp.min_length_bucket`: example length for the smallest bucket (i.e. the + smallest bucket will bucket examples up to this length). +* `hp.length_bucket_step`: controls how spaced out the length buckets are. + +## Building the Model + +At this point, the input features typically have `"inputs"` and `"targets"`, +each of which is a batched 4-D Tensor (e.g. of shape `[batch_size, +sequence_length, 1, 1]` for text input or `[batch_size, height, width, 3]` for +image input). + +A `T2TModel` is composed of transforms of the input features by `Modality`s, +then the body of the model, then transforms of the model output to predictions +by a `Modality`, and then a loss (during training). + +The `Modality` types for the various input features and for the target are +specified in `Problem.hparams`. A `Modality` is a feature adapter that enables +models to be agnostic to input/output spaces. You can see the various +`Modality`s in +[`modalities.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/modalities.py). + +The sketch structure of a T2T model is as follows: + +```python +features = {...} # output from the input pipeline +input_modaly = ... # specified in Problem.hparams +target_modality = ... # specified in Problem.hparams + +transformed_features = {} +transformed_features["inputs"] = input_modality.bottom( + features["inputs"]) +transformed_features["targets"] = target_modality.targets_bottom( + features["targets"]) # for autoregressive models + +body_outputs = model.model_fn_body(transformed_features) + +predictions = target_modality.top(body_outputs, features["targets"]) +loss = target_modality.loss(predictions, features["targets"]) +``` + +Most `T2TModel`s only override `model_fn_body`. + +## Training, Eval, Inference modes + +Both the input function and model functions take a mode in the form of a +`tf.estimator.ModeKeys`, which allows the functions to behave differently in +different modes. + +In training, the model function constructs an optimizer and minimizes the loss. + +In evaluation, the model function constructs the evaluation metrics specified by +`Problem.eval_metrics`. + +In inference, the model function outputs predictions. + +## `Estimator` and `Experiment` + +With the input function and model functions constructed, the actual training +loop and related services (checkpointing, summaries, continuous evaluation, +etc.) are all handled by `Estimator` and `Experiment` objects, constructed in +[`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py). + +## Decoding + +* [`decoding.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/decoding.py) + +TODO(rsepassi): Explain decoding (interactive, from file, and from dataset) and +`Problem.feature_encoders`. diff --git a/docs/index.md b/docs/index.md index 9394809b3..3eb7f1c61 100644 --- a/docs/index.md +++ b/docs/index.md @@ -24,11 +24,6 @@ documentation, from basic tutorials to full code documentation. ## Deep Dive -* [Life of an Example](example_life.md): how all parts of T2T are connected and work together +* [Life of an Example](example_life.md): how all parts of T2T are connected and + work together * [Distributed Training](distributed_training.md) - -## Code documentation - -See our -[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md) -for now, code docs coming. diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 57d7a03f4..0e97770ba 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -1,4 +1,4 @@ -# T2T Install and Run Walkthrough +# T2T: Tensor2Tensor Transformers [![PyPI version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/tensor2tensor) @@ -8,6 +8,26 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md) [![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby) [![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0) +[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)](https://travis-ci.org/tensorflow/tensor2tensor) + +[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible +library and binaries for supervised learning with TensorFlow and with support +for sequence tasks. It is actively used and maintained by researchers and +engineers within the Google Brain team. You can read more about Tensor2Tensor in +the recent [Google Research Blog post introducing +it](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html). + +We're eager to collaborate with you on extending T2T, so please feel +free to [open an issue on +GitHub](https://github.com/tensorflow/tensor2tensor/issues) or +send along a pull request to add your dataset or model. +See [our contribution +doc](CONTRIBUTING.md) for details and our [open +issues](https://github.com/tensorflow/tensor2tensor/issues). +You can chat with us and other users on +[Gitter](https://gitter.im/tensor2tensor/Lobby) and please join our +[Google Group](https://groups.google.com/forum/#!forum/tensor2tensor) to keep up +with T2T announcements. Here is a one-command version that installs tensor2tensor, downloads the data, trains an English-German translation model, and evaluates it: @@ -29,10 +49,28 @@ t2t-decoder \ --problems=translate_ende_wmt32k \ --model=transformer \ --hparams_set=transformer_base_single_gpu \ - --output_dir=~/t2t_train/base + --output_dir=~/t2t_train/base \ --decode_interactive ``` +See the [Walkthrough](#walkthrough) below for more details on each step. + +### Contents + +* [Walkthrough](#walkthrough) +* [Installation](#installation) +* [Features](#features) +* [T2T Overview](#t2t-overview) + * [Datasets](#datasets) + * [Problems and Modalities](#problems-and-modalities) + * [Models](#models) + * [Hyperparameter Sets](#hyperparameter-sets) + * [Trainer](#trainer) +* [Adding your own components](#adding-your-own-components) +* [Adding a dataset](#adding-a-dataset) + +--- + ## Walkthrough Here's a walkthrough training a good English-to-German translation @@ -80,16 +118,13 @@ echo "Goodbye world" >> $DECODE_FILE BEAM_SIZE=4 ALPHA=0.6 -t2t-trainer \ +t2t-decoder \ --data_dir=$DATA_DIR \ --problems=$PROBLEM \ --model=$MODEL \ --hparams_set=$HPARAMS \ --output_dir=$TRAIN_DIR \ - --train_steps=0 \ - --eval_steps=0 \ - --decode_beam_size=$BEAM_SIZE \ - --decode_alpha=$ALPHA \ + --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \ --decode_from_file=$DECODE_FILE cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes @@ -127,3 +162,136 @@ python -c "from tensor2tensor.models.transformer import Transformer" ``` --- + +## Features + +* Many state of the art and baseline models are built-in and new models can be + added easily (open an issue or pull request!). +* Many datasets across modalities - text, audio, image - available for + generation and use, and new ones can be added easily (open an issue or pull + request for public datasets!). +* Models can be used with any dataset and input mode (or even multiple); all + modality-specific processing (e.g. embedding lookups for text tokens) is done + with `Modality` objects, which are specified per-feature in the dataset/task + specification. +* Support for multi-GPU machines and synchronous (1 master, many workers) and + asynchronous (independent workers synchronizing through a parameter server) + [distributed training](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md). +* Easily swap amongst datasets and models by command-line flag with the data + generation script `t2t-datagen` and the training script `t2t-trainer`. + +--- + +## T2T overview + +### Datasets + +**Datasets** are all standardized on `TFRecord` files with `tensorflow.Example` +protocol buffers. All datasets are registered and generated with the +[data +generator](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-datagen) +and many common sequence datasets are already available for generation and use. + +### Problems and Modalities + +**Problems** define training-time hyperparameters for the dataset and task, +mainly by setting input and output **modalities** (e.g. symbol, image, audio, +label) and vocabularies, if applicable. All problems are defined either in +[`problem_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem_hparams.py) +or are registered with `@registry.register_problem` (run `t2t-datagen` to see +the list of all available problems). +**Modalities**, defined in +[`modality.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/modality.py), +abstract away the input and output data types so that **models** may deal with +modality-independent tensors. + +### Models + +**`T2TModel`s** define the core tensor-to-tensor transformation, independent of +input/output modality or task. Models take dense tensors in and produce dense +tensors that may then be transformed in a final step by a **modality** depending +on the task (e.g. fed through a final linear transform to produce logits for a +softmax over classes). All models are imported in the +[`models` subpackage](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/__init__.py), +inherit from `T2TModel` - defined in +[`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py) - +and are registered with +[`@registry.register_model`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/registry.py). + +### Hyperparameter Sets + +**Hyperparameter sets** are defined and registered in code with +[`@registry.register_hparams`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/registry.py) +and are encoded in +[`tf.contrib.training.HParams`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py) +objects. The `HParams` are available to both the problem specification and the +model. A basic set of hyperparameters are defined in +[`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py) +and hyperparameter set functions can compose other hyperparameter set functions. + +### Trainer + +The **trainer** binary is the main entrypoint for training, evaluation, and +inference. Users can easily switch between problems, models, and hyperparameter +sets by using the `--model`, `--problems`, and `--hparams_set` flags. Specific +hyperparameters can be overridden with the `--hparams` flag. `--schedule` and +related flags control local and distributed training/evaluation +([distributed training documentation](https://github.com/tensorflow/tensor2tensor/tree/master/docs/distributed_training.md)). + +--- + +## Adding your own components + +T2T's components are registered using a central registration mechanism that +enables easily adding new ones and easily swapping amongst them by command-line +flag. You can add your own components without editing the T2T codebase by +specifying the `--t2t_usr_dir` flag in `t2t-trainer`. + +You can do so for models, hyperparameter sets, modalities, and problems. Please +do submit a pull request if your component might be useful to others. + +Here's an example with a new hyperparameter set: + +```python +# In ~/usr/t2t_usr/my_registrations.py + +from tensor2tensor.models import transformer +from tensor2tensor.utils import registry + +@registry.register_hparams +def transformer_my_very_own_hparams_set(): + hparams = transformer.transformer_base() + hparams.hidden_size = 1024 + ... +``` + +```python +# In ~/usr/t2t_usr/__init__.py +from . import my_registrations +``` + +``` +t2t-trainer --t2t_usr_dir=~/usr/t2t_usr --registry_help +``` + +You'll see under the registered HParams your +`transformer_my_very_own_hparams_set`, which you can directly use on the command +line with the `--hparams_set` flag. + +`t2t-datagen` also supports the `--t2t_usr_dir` flag for `Problem` +registrations. + +## Adding a dataset + +To add a new dataset, subclass +[`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) +and register it with `@registry.register_problem`. See +[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +for an example. + +Also see the [data generators +README](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/README.md). + +--- + +*Note: This is not an official Google product.* diff --git a/setup.py b/setup.py index a84f772b6..331abb78e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.2.2', + version='1.2.3', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder index 8da8ae5a2..d2fe41f2f 100644 --- a/tensor2tensor/bin/t2t-decoder +++ b/tensor2tensor/bin/t2t-decoder @@ -46,6 +46,7 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS +flags.DEFINE_string("output_dir", "", "Training directory to load from.") flags.DEFINE_string("decode_from_file", None, "Path to decode file") flags.DEFINE_string("decode_to_file", None, "Path prefix to inference output file") @@ -58,6 +59,9 @@ flags.DEFINE_string("t2t_usr_dir", "", "The imported files should contain registrations, " "e.g. @registry.register_model calls, that will then be " "available to the t2t-decoder.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("schedule", "train_and_evaluate", + "Must be train_and_evaluate for decoding.") def main(_): @@ -65,16 +69,18 @@ def main(_): usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) trainer_utils.log_registry() trainer_utils.validate_flags() + assert FLAGS.schedule == "train_and_evaluate" data_dir = os.path.expanduser(FLAGS.data_dir) output_dir = os.path.expanduser(FLAGS.output_dir) hparams = trainer_utils.create_hparams( - FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) + FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) estimator, _ = trainer_utils.create_experiment_components( - hparams=hparams, - output_dir=output_dir, data_dir=data_dir, - model_name=FLAGS.model) + model_name=FLAGS.model, + hparams=hparams, + run_config=trainer_utils.create_run_config(output_dir)) decode_hp = decoding.decode_hparams(FLAGS.decode_hparams) decode_hp.add_hparam("shards", FLAGS.decode_shards) diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 7c7b48932..c986522f3 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -43,6 +43,7 @@ import tensorflow as tf flags = tf.flags FLAGS = flags.FLAGS +# See trainer_utils.py for additional command-line flags. flags.DEFINE_string("t2t_usr_dir", "", "Path to a Python module that will be imported. The " "__init__.py file should include the necessary imports. " @@ -53,6 +54,12 @@ flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen", "Temporary storage directory.") flags.DEFINE_bool("generate_data", False, "Generate data before training?") +flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_string("output_dir", "", "Base output directory for run.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("schedule", "train_and_evaluate", + "Method of tf.contrib.learn.Experiment to run.") + def main(_): tf.logging.set_verbosity(tf.logging.INFO) diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 52354704d..5877b541e 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -29,6 +29,7 @@ from tensor2tensor.data_generators import image from tensor2tensor.data_generators import imdb from tensor2tensor.data_generators import lm1b +from tensor2tensor.data_generators import problem_hparams from tensor2tensor.data_generators import ptb from tensor2tensor.data_generators import snli from tensor2tensor.data_generators import wiki diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index 93e846a0b..2f8e9cf30 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -129,7 +129,7 @@ def use_train_shards_for_dev(self): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: story_generator(tmp_dir)) + story_generator(tmp_dir)) for story in story_generator(tmp_dir): summary, rest = _story_summary_split(story) encoded_summary = encoder.encode(summary) + [EOS] diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py index 1e26b000c..174bd8107 100644 --- a/tensor2tensor/data_generators/desc2code.py +++ b/tensor2tensor/data_generators/desc2code.py @@ -195,8 +195,7 @@ def generator_target(): data_dir=data_dir, vocab_filename=self.vocab_target_filename, vocab_size=self.target_vocab_size, - generator_fn=generator_target, - ) + generator=generator_target(),) # Yield the training and testing samples eos_list = [EOS] diff --git a/tensor2tensor/data_generators/gene_expression.py b/tensor2tensor/data_generators/gene_expression.py index 43d5a6702..477e04017 100644 --- a/tensor2tensor/data_generators/gene_expression.py +++ b/tensor2tensor/data_generators/gene_expression.py @@ -159,17 +159,17 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): del mode # Reshape targets to contain num_output_predictions per output timestep - examples["targets"] = tf.reshape(examples["targets"], - [-1, 1, self.num_output_predictions]) + example["targets"] = tf.reshape(example["targets"], + [-1, 1, self.num_output_predictions]) # Slice off EOS - not needed, and messes up the GeneExpressionConv model # which expects the input length to be a multiple of the target length. - examples["inputs"] = examples["inputs"][:-1] + example["inputs"] = example["inputs"][:-1] - return examples + return example def eval_metrics(self): return [metrics.Metrics.LOG_POISSON, metrics.Metrics.R2] diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 3e1086d37..f22e84794 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -300,7 +300,7 @@ def gunzip_file(gz_path, new_path): def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, - generator_fn): + generator): """Inner implementation for vocab generators. Args: @@ -308,7 +308,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, then do not save the vocab even if it doesn't exist. vocab_filename: relative filename where vocab file is stored vocab_size: target size of the vocabulary constructed by SubwordTextEncoder - generator_fn: a generator that produces tokens from the vocabulary + generator: a generator that produces tokens from the vocabulary Returns: A SubwordTextEncoder vocabulary object. @@ -325,7 +325,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, tf.logging.info("Generating vocab file: %s", vocab_filepath) token_counts = defaultdict(int) - for item in generator_fn(): + for item in generator: for tok in tokenizer.encode(text_encoder.native_to_unicode(item)): token_counts[tok] += 1 @@ -382,8 +382,8 @@ def generate(): file_byte_budget -= len(line) yield line - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename, @@ -416,8 +416,8 @@ def generate(): part = parts[index].strip() yield part - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def get_or_generate_txt_vocab(data_dir, vocab_filename, vocab_size, @@ -434,8 +434,8 @@ def generate(): for line in source_file: yield line.strip() - return get_or_generate_vocab_inner( - data_dir, vocab_filename, vocab_size, generator_fn=generate) + return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, + generate()) def read_records(filename): diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 64b9d8639..084ef330a 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -91,19 +91,19 @@ class ImageCeleba(ImageProblem): "Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young" ).split() - def preprocess_examples(self, examples, unused_mode, unused_hparams): + def preprocess_example(self, example, unused_mode, unused_hparams): def resize(img, size): return tf.to_int64( tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = examples["inputs"] + inputs = example["inputs"] # Remove boundaries in CelebA images. Remove 40 pixels each side # vertically and 20 pixels each side horizontally. inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40) - examples["inputs"] = resize(inputs, 8) - examples["targets"] = resize(inputs, 32) - return examples + example["inputs"] = resize(inputs, 8) + example["targets"] = resize(inputs, 32) + return example def hparams(self, defaults, unused_model_hparams): p = defaults @@ -301,7 +301,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): self.dev_filepaths(data_dir, self.dev_shards, shuffled=False)) -def imagenet_preprocess_examples(examples, mode): +def imagenet_preprocess_example(example, mode): """Preprocessing used for Imagenet and similar problems.""" def preprocess(img): @@ -312,15 +312,15 @@ def preprocess(img): def resize(img): return tf.to_int64(tf.image.resize_images(img, [299, 299])) - inputs = tf.cast(examples["inputs"], tf.int64) + inputs = tf.cast(example["inputs"], tf.int64) if mode == tf.estimator.ModeKeys.TRAIN: - examples["inputs"] = tf.cond( # Preprocess 90% of the time. + example["inputs"] = tf.cond( # Preprocess 90% of the time. tf.less(tf.random_uniform([]), 0.9), lambda img=inputs: preprocess(img), lambda img=inputs: resize(img)) else: - examples["inputs"] = resize(inputs) - return examples + example["inputs"] = resize(inputs) + return example @registry.register_problem @@ -341,8 +341,8 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): "instructions at https://github.com/tensorflow/models/blob/master" "/inception/README.md#getting-started") - def preprocess_examples(self, examples, mode, _): - return imagenet_preprocess_examples(examples, mode) + def preprocess_example(self, example, mode, _): + return imagenet_preprocess_example(example, mode) @registry.register_problem @@ -366,17 +366,17 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): "instructions at https://github.com/tensorflow/models/blob/master" "/inception/README.md#getting-started") - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): # Just resize with area. if self._was_reversed: - examples["inputs"] = tf.to_int64( - tf.image.resize_images(examples["inputs"], [32, 32], + example["inputs"] = tf.to_int64( + tf.image.resize_images(example["inputs"], [32, 32], tf.image.ResizeMethod.AREA)) else: - examples = imagenet_preprocess_examples(examples, mode) - examples["inputs"] = tf.to_int64( - tf.image.resize_images(examples["inputs"], [32, 32])) - return examples + example = imagenet_preprocess_example(example, mode) + example["inputs"] = tf.to_int64( + tf.image.resize_images(example["inputs"], [32, 32])) + return example @registry.register_problem @@ -386,17 +386,17 @@ class Img2imgImagenet(ImageProblem): def dataset_filename(self): return "image_imagenet" # Reuse Imagenet data. - def preprocess_examples(self, examples, unused_mode, unused_hparams): + def preprocess_example(self, example, unused_mode, unused_hparams): def resize(img, size): return tf.to_int64( tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = examples["inputs"] + inputs = example["inputs"] # For Img2Img resize input and output images as desired. - examples["inputs"] = resize(inputs, 8) - examples["targets"] = resize(inputs, 32) - return examples + example["inputs"] = resize(inputs, 8) + example["targets"] = resize(inputs, 32) + return example def hparams(self, defaults, unused_model_hparams): p = defaults @@ -623,11 +623,11 @@ def class_labels(self): "ship", "truck" ] - def preprocess_examples(self, examples, mode, unused_hparams): + def preprocess_example(self, example, mode, unused_hparams): if mode == tf.estimator.ModeKeys.TRAIN: - examples["inputs"] = common_layers.cifar_image_augmentation( - examples["inputs"]) - return examples + example["inputs"] = common_layers.cifar_image_augmentation( + example["inputs"]) + return example def generator(self, data_dir, tmp_dir, is_training): if is_training: @@ -649,8 +649,8 @@ def generator(self, data_dir, tmp_dir, is_training): @registry.register_problem class ImageCifar10Plain(ImageCifar10): - def preprocess_examples(self, examples, mode, unused_hparams): - return examples + def preprocess_example(self, example, mode, unused_hparams): + return example # URLs and filenames for MSCOCO data. @@ -827,8 +827,8 @@ def train_shards(self): def dev_shards(self): return 10 - def preprocess_examples(self, examples, mode, _): - return imagenet_preprocess_examples(examples, mode) + def preprocess_example(self, example, mode, _): + return imagenet_preprocess_example(example, mode) def generator(self, data_dir, tmp_dir, is_training): if is_training: diff --git a/tensor2tensor/data_generators/imdb.py b/tensor2tensor/data_generators/imdb.py index d7eadcd1d..95d728b1e 100644 --- a/tensor2tensor/data_generators/imdb.py +++ b/tensor2tensor/data_generators/imdb.py @@ -79,7 +79,7 @@ def generator(self, data_dir, tmp_dir, train): # Generate vocab encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: self.doc_generator(imdb_dir, "train")) + self.doc_generator(imdb_dir, "train")) # Generate examples dataset = "train" if train else "test" diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py index 848b74a2d..c84f00606 100644 --- a/tensor2tensor/data_generators/inspect.py +++ b/tensor2tensor/data_generators/inspect.py @@ -67,9 +67,9 @@ def main(_): inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value] targets = [int(i) for i in x.features.feature["targets"].int64_list.value] if FLAGS.print_inputs: - print(encoder.decode(inputs) if encoder else inputs) + print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs) if FLAGS.print_targets: - print(encoder.decode(targets) if encoder else targets) + print("TARGETS:\n" + encoder.decode(targets) if encoder else targets) total_input_tokens += len(inputs) total_target_tokens += len(targets) total_sequences += 1 diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 4aa4862ef..37eee64ab 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -17,20 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - import collections import os import random - # Dependency imports - import six - from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder from tensor2tensor.utils import metrics from tensor2tensor.utils import registry - import tensorflow as tf @@ -107,16 +102,19 @@ def default_model_hparams(): data_dir=None) -def preprocess_examples_common(examples, hparams): +def preprocess_example_common(example, hparams, mode): """Preprocessing steps common to all models.""" if hparams.max_input_seq_length > 0: - examples["inputs"] = examples["inputs"][:hparams.max_input_seq_length] + example["inputs"] = example["inputs"][:hparams.max_input_seq_length] if hparams.max_target_seq_length > 0: - examples["targets"] = examples["targets"][:hparams.max_target_seq_length] + example["targets"] = example["targets"][:hparams.max_target_seq_length] if hparams.prepend_mode != "none": - examples["targets"] = tf.concat( - [examples["inputs"], [0], examples["targets"]], 0) - return examples + if mode == tf.estimator.ModeKeys.PREDICT: + example["partial_targets"] = tf.concat([example["inputs"], [0]], 0) + else: + example["targets"] = tf.concat( + [example["inputs"], [0], example["targets"]], 0) + return example class Problem(object): @@ -156,7 +154,7 @@ class Problem(object): * example_reading_spec - Specify the names and types of the features on disk. - Specify tf.contrib.slim.tfexample_decoder - * preprocess_examples(examples, mode) + * preprocess_example(example, mode) - Preprocess the example feature dict from feature name to Tensor or SparseTensor. - Used in training, eval, and inference (specified by mode). @@ -200,9 +198,8 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, mode, hparams): - del mode - return preprocess_examples_common(examples, hparams) + def preprocess_example(self, example, mode, hparams): + return preprocess_example_common(example, hparams, mode) def eval_metrics(self): return [ @@ -260,10 +257,9 @@ def get_hparams(self, model_hparams=None): if self._hparams is not None: return self._hparams - assert model_hparams is not None - if self._encoders is None: - self.get_feature_encoders(model_hparams.data_dir) + data_dir = (model_hparams and model_hparams.data_dir) or None + self.get_feature_encoders(data_dir) hp = _default_hparams() ret = self.hparams(hp, model_hparams) @@ -314,10 +310,10 @@ def dataset(self, shuffle_files: whether to shuffle input files. Default behavior (i.e. when shuffle_files=None) is to shuffle if mode == TRAIN. hparams: tf.contrib.training.HParams; hparams to be passed to - Problem.preprocess_examples and Problem.hparams. If None, will use a + Problem.preprocess_example and Problem.hparams. If None, will use a default set that is a no-op. preprocess: bool, whether to map the Dataset through - Problem.preprocess_examples. + Problem.preprocess_example. Returns: Dataset containing dict. @@ -370,7 +366,7 @@ def decode_record(record): return dict(zip(decode_items, decoded)) def _preprocess(example): - example = self.preprocess_examples(example, mode, hparams) + example = self.preprocess_example(example, mode, hparams) self.maybe_reverse_features(example) self.maybe_copy_features(example) return example @@ -385,6 +381,10 @@ def _preprocess(example): return dataset + @property + def has_inputs(self): + return "inputs" in self.get_feature_encoders() + @property def feature_info(self): """Retrieve dict. @@ -404,7 +404,8 @@ def feature_info(self): input_mods = hp.input_modality target_mod = hp.target_modality vocabs = hp.vocabulary - in_id = hp.input_space_id + if self.has_inputs: + in_id = hp.input_space_id out_id = hp.target_space_id features = collections.defaultdict(FeatureInfo) @@ -422,7 +423,8 @@ def feature_info(self): for name, encoder in six.iteritems(vocabs): features[name].encoder = encoder - features["inputs"].space_id = in_id + if self.has_inputs: + features["inputs"].space_id = in_id features["targets"].space_id = out_id self._feature_info = features diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 147fc7538..576a27a79 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -24,345 +24,185 @@ # Dependency imports +from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder from tensor2tensor.layers import modalities # pylint: disable=unused-import from tensor2tensor.utils import registry import tensorflow as tf - -def problem_hparams(problem_name, model_hparams): - """Generate problem hyperparameters based on problem name. - - Args: - problem_name: a string - model_hparams: a tf.contrib.training.HParams - - Returns: - a tf.contrib.training.HParams - """ - base_name, was_reversed, was_copy = parse_problem_name(problem_name) - p = _lookup_problem_hparams_fn(base_name)(model_hparams) - if was_reversed: - _reverse_problem_hparams(p) - if was_copy: - _copy_problem_hparams(p) - return p - - -def parse_problem_name(problem_name): - """Determines if problem_name specifies a copy and/or reversal. - - Args: - problem_name: A string containing a single problem name from FLAGS.problems. - - Returns: - base_name: A string with the base problem name. - was_reversed: A boolean. - was_copy: A boolean. - """ - # Recursively strip tags until we reach a base name. - if problem_name.endswith("_rev"): - base, _, was_copy = parse_problem_name(problem_name[:-4]) - return base, True, was_copy - elif problem_name.endswith("_copy"): - base, was_reversed, _ = parse_problem_name(problem_name[:-5]) - return base, was_reversed, True - return problem_name, False, False - - -def _lookup_problem_hparams_fn(name): - if name not in PROBLEM_HPARAMS_MAP: - map_str = "* " + "\n* ".join(sorted(PROBLEM_HPARAMS_MAP.keys())) - error_msg = "%s not in the supported set of problems:\n%s" % (name, map_str) - raise LookupError(error_msg) - return PROBLEM_HPARAMS_MAP.get(name) - - -def _copy_problem_hparams(p_hparams): - """Use input modality, vocab, and space id for target.""" - p = p_hparams - # Duplicate input modality. - p.target_modality = p.input_modality["inputs"] - # Duplicate input vocabulary. - p.vocabulary["targets"] = p.vocabulary["inputs"] - # Duplicate input space ids. - p.target_space_id = p.input_space_id - # Mark that p was reversed. - p.was_copy = True - - -def _reverse_problem_hparams(p_hparams): - """Swap input/output modalities, vocab, and space ids.""" - p = p_hparams - - # Swap modalities. - input_modality = p.input_modality["inputs"] - target_modality = p.target_modality - p.input_modality["inputs"] = target_modality - p.target_modality = input_modality - - # Swap vocabularies. - input_vocabulary = p.vocabulary["inputs"] - target_vocabulary = p.vocabulary["targets"] - p.vocabulary["inputs"] = target_vocabulary - p.vocabulary["targets"] = input_vocabulary - - # Swap input/target space ids. - input_space_id = p.input_space_id - target_space_id = p.target_space_id - p.input_space_id = target_space_id - p.target_space_id = input_space_id - - # Mark that p was reversed. - p.was_reversed = True - - -def default_problem_hparams(): - """A set of basic model hyperparameters.""" - return tf.contrib.training.HParams( - # Use this parameter to get comparable perplexity numbers with different - # tokenizations. This value should be set to the ratio of the number of - # tokens in the test set according to the tokeization used to the number - # of tokens in the test set in the "official" tokenization. For example, - # if we are using a word-piece based model and we want to compute - # per-word perplexity, then we set loss_multiplier to the number of - # wordpieces per word in the test set. - loss_multiplier=1.0, - - # Use this parameter to allow for larger sequences in the batch. Without - # the use of this parameter, the size of the inner two dimensions will be - # used to judge the sequence length. - batch_size_multiplier=1, - - # To make queues of the right capacity, it's good to know the maximal - # expected batch size, as it can vary a lot. It only affects performance - # of input readers and memory use. The defaults should be safe and fast, - # but decrease if your reader uses a lot of memory and increase if slow. - max_expected_batch_size_per_shard=64, - - # Modalities used to map from input features to a space compatible with - # chosen model architecture. One modality spec (which is a 2-tuple, - # (modality_full_name, vocab_size)) per feature key. modality_full_name is - # a string type:name, e.g. class_label:2d. Leaving off the name uses the - # default modality for that type (e.g. class_label == - # class_label:default). - input_modality={}, - - # Modality used to map from hidden representation to the target space. - # Specified as a modality spec, a 2-tuple described above. - target_modality=None, - - # Identifiers used to tell the model which input/target space will be - # expected. For example, it can tell that we expect French as characters - # as output, or Spanish as sound. An integer with the following semantics: - # 0: Generic / unknown output space (default) - # 1: Image labels - # 2: English characters - # 3: English tokens - # 4: English bpe tokens - # 5: French characters - # 6: French tokens - # 7: German characters - # 8: German tokens - # 9: German bpe tokens - # 10: Digit cipher lexicon 0 - # 11: Digit cipher lexicon 1 - # 12: Audio waveform domain - # 13: Audio spectral domain - # 14: Parse characters - # 15: Parse tokens - # 16: Chinese tokens - # 17: Icelandic characters - # 18: Icelandic tokens - # 19: Icelandic parse tokens - # 20: Macedonian tokens - # 21: Czech tokens - # 22: Czech characters - # Add more above if needed. - input_space_id=0, - target_space_id=0, - - # Vocabulary per feature key. - # a vocabulary converts to/from human-readable strings. - # E.g. {"inputs": text_encoder.ByteTextEncoder(), - # "targets": text_encoder.SubwordTextEncoder("vocab_filename.txt")} - vocabulary={ - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.TextEncoder() - }, - - # This is a marker to keep track if the problem was reversed or copied. - # Only set automatically, do not override the default. - # - # These tags can be combined in order to perform copies of the input or - # the targets. For instance `problem_copy` will copy the inputs, but - # `problem_rev_copy` will copy the targets. - was_reversed=False, - was_copy=False,) - - -def test_problem_hparams(unused_model_hparams, input_vocab_size, - target_vocab_size): +# TODO(rsepassi): Merge these problems with their data generators. Currenlty +# they only implement the hparams. + + +class AudioTimitProblem(problem.Problem): + """Base class for TIMIT problems.""" + + def example_reading_spec(self): + data_fields = { + "inputs": tf.VarLenFeature(tf.int64), + "audio/sample_count": tf.FixedLenFeature((), tf.int64), + "audio/sample_width": tf.FixedLenFeature((), tf.int64), + "targets": tf.VarLenFeature(tf.int64), + } + return data_fields, None + + def preprocess_example(self, example, mode, hparams): + example = super(AudioTimitProblem, self).preprocess_example( + example, mode, hparams) + # Reshape audio to proper shape + sample_count = tf.to_int32(example.pop("audio/sample_count")) + sample_width = tf.to_int32(example.pop("audio/sample_width")) + channel_count = 1 + example["inputs"] = tf.reshape(example["inputs"], + [sample_count, sample_width, channel_count]) + return example + + +@registry.register_problem +class AudioTimitCharactersTune(AudioTimitProblem): + """TIMIT to characters.""" + + def feature_encoders(self, _): + return { + "inputs": text_encoder.TextEncoder(), + "targets": text_encoder.ByteTextEncoder(), + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.AUDIO, None), + } + hp.target_modality = (registry.Modalities.SYMBOL, 256) + + +@registry.register_problem +class AudioTimitTokens8kTune(AudioTimitProblem): + """TIMIT to tokens.""" + + @property + def target_vocab_size(self): + return 2**13 # 8192 + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, + "vocab.endefr.%d" % self.target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": text_encoder.TextEncoder(), + "targets": subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.AUDIO, None), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.batch_size_multiplier = 256 + hp.loss_multiplier = 2.0 + hp.input_space_id = 13 + hp.target_space_id = 3 + + +@registry.register_problem +class AudioTimitTokens8kTest(AudioTimitTokens8kTune): + """TIMIT to tokens.""" + pass + + +@registry.register_problem +class ParsingEnglishPtb8k(problem.Problem): + """Parsing.""" + + @property + def target_vocab_size(self): + return 2**13 # 8192 + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, + "vocab.endefr.%d" % self.target_vocab_size) + subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) + return { + "inputs": subtokenizer, + "targets": subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, + self.get_feature_encoders()["inputs"].vocab_size), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.batch_size_multiplier = 256 + hp.loss_multiplier = 2.0 + hp.input_space_id = 3 + hp.target_space_id = 15 + + +@registry.register_problem +class ParsingEnglishPtb16k(problem.Problem): + """Parsing.""" + + @property + def vocab_prefix(self): + return "wsj" + + @property + def inputs_target_vocab_size(self): + return 2**9 # 512 + + @property + def targets_target_vocab_size(self): + return 2**14 # 16384 + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join( + data_dir, + self.vocab_prefix + "_source.vocab.%d" % self.inputs_target_vocab_size) + target_vocab_filename = os.path.join( + data_dir, + self.vocab_prefix + "_target.vocab.%d" % self.targets_target_vocab_size) + source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_subtokenizer, + "targets": target_subtokenizer, + } + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, + self.get_feature_encoders()["inputs"].vocab_size), + } + hp.target_modality = (registry.Modalities.SYMBOL, + self.get_feature_encoders()["targets"].vocab_size) + hp.input_space_id = 3 + hp.target_space_id = 15 + + +class TestProblem(problem.Problem): + """Test problem.""" + + def __init__(self, input_vocab_size, target_vocab_size): + super(TestProblem, self).__init__(False, False) + self.input_vocab_size = input_vocab_size + self.target_vocab_size = target_vocab_size + + def hparams(self, defaults, model_hparams): + hp = defaults + hp.input_modality = { + "inputs": (registry.Modalities.SYMBOL, self.input_vocab_size) + } + hp.target_modality = (registry.Modalities.SYMBOL, self.target_vocab_size) + + +def test_problem_hparams(input_vocab_size=None, target_vocab_size=None): """Problem hparams for testing model bodies.""" - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, input_vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, target_vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.TextEncoder() - } - return p - - -def audio_timit_characters(unused_model_hparams): - """English audio transcription benchmark.""" - p = default_problem_hparams() - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.batch_size_multiplier = 256 - p.loss_multiplier = 2.0 - p.input_space_id = 12 - p.target_space_id = 2 - return p - - -def audio_timit_tokens(model_hparams, wrong_vocab_size): - """English audio transcription benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "vocab.endefr.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.AUDIO, None), - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": text_encoder.TextEncoder(), - "targets": subtokenizer, - } - p.batch_size_multiplier = 256 - p.loss_multiplier = 2.0 - p.input_space_id = 13 - p.target_space_id = 3 - return p - - -def wmt_parsing_characters(model_hparams): - """English to parse tree translation benchmark.""" - del model_hparams # Unused. - p = default_problem_hparams() - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, 256)} - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": text_encoder.ByteTextEncoder(), - "targets": text_encoder.ByteTextEncoder(), - } - p.loss_multiplier = 2.0 - p.input_space_id = 2 - p.target_space_id = 14 - return p - - -def wmt_parsing_tokens(model_hparams, wrong_vocab_size): - """English to parse tree translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_vocab_size: a number used in the filename indicating the approximate - vocabulary size. This is not to be confused with the actual vocabulary - size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - vocab_filename = os.path.join(model_hparams.data_dir, - "vocab.endefr.%d" % wrong_vocab_size) - subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) - p.vocabulary = { - "inputs": subtokenizer, - "targets": subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 15 - return p - - -def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size, - wrong_target_vocab_size): - """English to parse tree translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - prefix: name to use as prefix for vocabulary files. - wrong_source_vocab_size: a number used in the filename indicating the - approximate vocabulary size. This is not to be confused with the actual - vocabulary size. - wrong_target_vocab_size: a number used in the filename indicating the - approximate target vocabulary size. This is not to be confused with the - actual target vocabulary size. - Returns: - a tf.contrib.training.HParams - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - source_vocab_filename = os.path.join( - model_hparams.data_dir, - prefix + "_source.vocab.%d" % wrong_source_vocab_size) - target_vocab_filename = os.path.join( - model_hparams.data_dir, - prefix + "_target.vocab.%d" % wrong_target_vocab_size) - source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, - target_subtokenizer.vocab_size) - p.vocabulary = { - "inputs": source_subtokenizer, - "targets": target_subtokenizer, - } - p.input_space_id = 3 - p.target_space_id = 15 - return p - - -# Dictionary of named hyperparameter settings for various problems. -# This is only accessed through the problem_hparams function below. -PROBLEM_HPARAMS_MAP = { - "audio_timit_characters_tune": - audio_timit_characters, - "audio_timit_characters_test": - audio_timit_characters, - "audio_timit_tokens_8k_tune": - lambda p: audio_timit_tokens(p, 2**13), - "audio_timit_tokens_8k_test": - lambda p: audio_timit_tokens(p, 2**13), - "parsing_english_ptb8k": - lambda p: wmt_parsing_tokens(p, 2**13), - "parsing_english_ptb16k": - lambda p: wsj_parsing_tokens( # pylint: disable=g-long-lambda - p, "wsj", 2**14, 2**9), -} + p = TestProblem(input_vocab_size, target_vocab_size) + return p.get_hparams() diff --git a/tensor2tensor/data_generators/problem_hparams_test.py b/tensor2tensor/data_generators/problem_hparams_test.py deleted file mode 100644 index df92919ef..000000000 --- a/tensor2tensor/data_generators/problem_hparams_test.py +++ /dev/null @@ -1,50 +0,0 @@ -# coding=utf-8 -# Copyright 2017 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for tensor2tensor.problem_hparams.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from tensor2tensor.data_generators import problem_hparams - -import tensorflow as tf - - -class ProblemHparamsTest(tf.test.TestCase): - - def testParseProblemName(self): - problem_name = "base" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", False, - False)) - problem_name = "base_rev" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, False)) - problem_name = "base_copy" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", False, True)) - problem_name = "base_copy_rev" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, True)) - problem_name = "base_rev_copy" - self.assertEqual( - problem_hparams.parse_problem_name(problem_name), ("base", True, True)) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py index 893c2b77c..31bc83c0a 100644 --- a/tensor2tensor/data_generators/ptb.py +++ b/tensor2tensor/data_generators/ptb.py @@ -42,9 +42,9 @@ def _read_words(filename): """Reads words from a file.""" with tf.gfile.GFile(filename, "r") as f: if sys.version_info[0] >= 3: - return f.read().replace("\n", " ").split() + return f.read().replace("\n", " %s " % EOS).split() else: - return f.read().decode("utf-8").replace("\n", " ").split() + return f.read().decode("utf-8").replace("\n", " %s " % EOS).split() def _build_vocab(filename, vocab_path, vocab_size): @@ -151,7 +151,7 @@ def generator(self, data_dir, tmp_dir, train): def _generator(self, filename, encoder): with tf.gfile.GFile(filename, "r") as f: for line in f: - line = " ".join(line.replace("\n", EOS).split()) + line = " ".join(line.replace("\n", " %s " % EOS).split()) tok = encoder.encode(line) if tok: yield {"inputs": [0], "targets": tok} diff --git a/tensor2tensor/data_generators/text_encoder_test.py b/tensor2tensor/data_generators/text_encoder_test.py index b55a51bf4..0351d0d2f 100644 --- a/tensor2tensor/data_generators/text_encoder_test.py +++ b/tensor2tensor/data_generators/text_encoder_test.py @@ -107,6 +107,13 @@ def test_reserved_tokens_in_corpus(self): class SubwordTextEncoderTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + """Make sure the test dir exists and is empty.""" + cls.test_temp_dir = os.path.join(tf.test.get_temp_dir(), "encoder_test") + shutil.rmtree(cls.test_temp_dir, ignore_errors=True) + os.mkdir(cls.test_temp_dir) + def test_encode_decode(self): corpus = ( "This is a corpus of text that provides a bunch of tokens from which " @@ -216,6 +223,28 @@ def test_load_from_file(self): encoder._load_from_file_object(vocab) self.assertEqual(encoder._all_subtoken_strings, correct_vocab) + def test_reserved_token_chars_not_in_alphabet(self): + corpus = "dog" + token_counts = collections.Counter(corpus.split(" ")) + encoder1 = text_encoder.SubwordTextEncoder.build_to_target_size( + 100, token_counts, 2, 100) + filename = os.path.join(self.test_temp_dir, "out.voc") + encoder1.store_to_file(filename) + encoder2 = text_encoder.SubwordTextEncoder(filename=filename) + + for t in text_encoder.RESERVED_TOKENS: + for c in t: + # Verify that encoder1 can encode all reserved token chars. + encoder1.encode(c) + + # TODO(seabass): Implement the fix so that we can remove this assertion. + with self.assertRaises(AssertionError): + for t in text_encoder.RESERVED_TOKENS: + for c in t: + # Verify that encoder2 fails to encode the characters (i.e. + # reproduce the bug). + encoder2.encode(c) + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py index 6f6c97686..a1380c27f 100644 --- a/tensor2tensor/data_generators/wiki.py +++ b/tensor2tensor/data_generators/wiki.py @@ -31,6 +31,7 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import metrics from tensor2tensor.utils import registry import tensorflow as tf @@ -126,7 +127,7 @@ def use_train_shards_for_dev(self): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: page_generator(tmp_dir, max_docs=10000)) + page_generator(tmp_dir, max_docs=10000)) for page in page_generator(tmp_dir): title = _page_title(page) encoded = encoder.encode(page) + [EOS] @@ -209,7 +210,7 @@ def scramble(self, seq): def generator(self, data_dir, tmp_dir, _): encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - lambda: page_generator(tmp_dir, max_docs=1000)) + page_generator(tmp_dir, max_docs=1000)) case_num = 0 for page in page_generator(tmp_dir): encoded = encoder.encode(page) @@ -222,6 +223,24 @@ def generator(self, data_dir, tmp_dir, _): inputs = self.scramble(targets) yield {"inputs": inputs, "targets": targets} + def eval_metrics(self): + return [ + metrics.Metrics.ACC, metrics.Metrics.NEG_LOG_PERPLEXITY + ] + + +@registry.register_problem +class LanguagemodelWikiScramble128(LanguagemodelWikiScramble): + """Sequence length 128, 50% scrambed.""" + + @property + def sequence_length(self): + return 128 + + @property + def scramble_fraction(self): + return 0.5 + @registry.register_problem class LanguagemodelWikiScramble1k50(LanguagemodelWikiScramble): diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index befb9ac7f..cde0bc9ac 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -34,7 +34,6 @@ FLAGS = tf.flags.FLAGS - # End-of-sentence marker. EOS = text_encoder.EOS_ID @@ -186,7 +185,6 @@ def bi_vocabs_token_generator(source_path, # Data-set URLs. - _ENDE_TRAIN_DATASETS = [ [ "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", # pylint: disable=line-too-long @@ -287,7 +285,6 @@ def bi_vocabs_token_generator(source_path, ], ] - # Generators. @@ -333,8 +330,8 @@ def generator(self, data_dir, tmp_dir, train): with tf.gfile.GFile(token_path, mode="a") as f: f.write("UNK\n") # Add UNK to the vocab. token_vocab = text_encoder.TokenTextEncoder(token_path, replace_oov="UNK") - return token_generator(train_path + ".en", train_path + ".de", - token_vocab, EOS) + return token_generator(train_path + ".en", train_path + ".de", token_vocab, + EOS) @property def input_space_id(self): @@ -360,7 +357,7 @@ def _preprocess_sgm(line, is_sgm): line = line.strip() if line.startswith(""): i = line.index(">") - return line[i+1:-6] # Strip first and last . + return line[i + 1:-6] # Strip first and last . def _compile_data(tmp_dir, datasets, filename): @@ -479,18 +476,24 @@ def targeted_vocab_size(self): def num_shards(self): return 10 # This is a small dataset. + @property + def source_vocab_name(self): + return "vocab.zhen-zh.%d" % self.targeted_vocab_size + + @property + def target_vocab_name(self): + return "vocab.zhen-en.%d" % self.targeted_vocab_size + def generator(self, data_dir, tmp_dir, train): - source_vocab_size = self.targeted_vocab_size - target_vocab_size = self.targeted_vocab_size datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS] target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS] source_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.zhen-zh.%d" % source_vocab_size, - source_vocab_size, source_datasets) + data_dir, tmp_dir, self.source_vocab_name, self.targeted_vocab_size, + source_datasets) target_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.zhen-en.%d" % target_vocab_size, - target_vocab_size, target_datasets) + data_dir, tmp_dir, self.target_vocab_name, self.targeted_vocab_size, + target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) # We generate English->X data by convention, to train reverse translation @@ -508,11 +511,8 @@ def target_space_id(self): return problem.SpaceID.EN_TOK def feature_encoders(self, data_dir): - vocab_size = self.targeted_vocab_size - source_vocab_filename = os.path.join(data_dir, - "vocab.zhen-zh.%d" % vocab_size) - target_vocab_filename = os.path.join(data_dir, - "vocab.zhen-en.%d" % vocab_size) + source_vocab_filename = os.path.join(data_dir, self.source_vocab_name) + target_vocab_filename = os.path.join(data_dir, self.target_vocab_name) source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) return { diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 6f7c9fa23..582f8e9b3 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -37,6 +37,51 @@ _expert_count = 0 +def get_timing_signal_1d( + length, channels, min_timescale=1.0, max_timescale=1.0e4): + """Gets a bunch of sinusoids of different frequencies. + + Each channel of the input Tensor is incremented by a sinusoid of a different + frequency and phase. + + This allows attention to learn to use absolute and relative positions. + Timing signals should be added to some precursors of both the query and the + memory inputs to attention. + + The use of relative position is possible because sin(x+y) and cos(x+y) can be + experessed in terms of y, sin(x) and cos(x). + + In particular, we use a geometric sequence of timescales starting with + min_timescale and ending with max_timescale. The number of different + timescales is equal to channels / 2. For each timescale, we + generate the two sinusoidal signals sin(timestep/timescale) and + cos(timestep/timescale). All of these sinusoids are concatenated in + the channels dimension. + + Args: + length: scalar, length of timing signal sequence. + channels: scalar, size of timing embeddings to create. The number of + different timescales is equal to channels / 2. + min_timescale: a float + max_timescale: a float + + Returns: + a Tensor of timing signals [1, length, channels] + """ + position = tf.to_float(tf.range(length)) + num_timescales = channels // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (tf.to_float(num_timescales) - 1)) + inv_timescales = min_timescale * tf.exp( + tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) + scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) + signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) + signal = tf.reshape(signal, [1, length, channels]) + return signal + + def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): """Adds a bunch of sinusoids of different frequencies to a Tensor. @@ -67,17 +112,34 @@ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): """ length = tf.shape(x)[1] channels = tf.shape(x)[2] - position = tf.to_float(tf.range(length)) + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal + + +def add_timing_signal_1d_given_position(x, position, min_timescale=1.0, + max_timescale=1.0e4): + """Adds sinusoids of diff frequencies to a Tensor, with timing position given. + + Args: + x: a Tensor with shape [batch, length, channels] + position: a Tensor with shape [batch, length] + min_timescale: a float + max_timescale: a float + + Returns: + a Tensor the same shape as x. + """ + channels = tf.shape(x)[2] num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (tf.to_float(num_timescales) - 1)) inv_timescales = min_timescale * tf.exp( tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) - scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) - signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) - signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) - signal = tf.reshape(signal, [1, length, channels]) + scaled_time = (tf.expand_dims(tf.to_float(position), 2) * + tf.expand_dims(tf.expand_dims(inv_timescales, 0), 0)) + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=2) + signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(channels, 2)]]) return x + signal @@ -189,18 +251,42 @@ def embedding_to_padding(emb): return tf.to_float(tf.equal(emb_sum, 0.0)) +def attention_bias_local(length, max_backward, max_forward): + """Create an bias tensor to be added to attention logits. + + A position may attend to positions at most max_distance from it, + forward and backwards. + + This does not actually save any computation. + + Args: + length: an integer Scalar. + max_backward: an int64 Scalar - maximum distance backward to attend. + negative values indicate unlimited. + max_forward: an int64 Scalar - maximum distance forward to attend. + negative values indicate unlimited. + + Returns: + a `Tensor` with shape [1, 1, length, length]. + """ + band = tf.matrix_band_part( + tf.ones([length, length]), max_backward, max_forward) + ret = -1e9 * (1.0 - band) + return tf.reshape(ret, [1, 1, length, length]) + + def attention_bias_lower_triangle(length): """Create an bias tensor to be added to attention logits. + Allows a query to attend to all positions up to and including its own. + Args: length: a Scalar. Returns: a `Tensor` with shape [1, 1, length, length]. """ - lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0) - ret = -1e9 * (1.0 - lower_triangle) - return tf.reshape(ret, [1, 1, length, length]) + return attention_bias_local(length, -1, 0) def attention_bias_ignore_padding(memory_padding): @@ -416,7 +502,8 @@ def dot_product_attention(q, bias, dropout_rate=0.0, image_shapes=None, - name=None): + name=None, + make_image_summary=True): """dot-product attention. Args: @@ -428,6 +515,7 @@ def dot_product_attention(q, image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() name: an optional string + make_image_summary: True if you want an image summary. Returns: A Tensor. @@ -443,7 +531,8 @@ def dot_product_attention(q, weights = tf.nn.dropout(weights, 1.0 - dropout_rate) if (not tf.get_variable_scope().reuse and # Summaries don't work well within tf.while_loop() - "/while/" not in tf.contrib.framework.get_name_scope()): + "/while/" not in tf.contrib.framework.get_name_scope() and + make_image_summary): attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) @@ -479,7 +568,7 @@ def masked_local_attention_1d( # If (length < 2 * block_length), then we use only one block. block_length = tf.where(tf.less(length, block_length * 2), length, block_length) - depth_k = tf.shape(q)[3] + depth_k = tf.shape(k)[3] depth_v = tf.shape(v)[3] original_length = length padding_size = tf.mod(-length, block_length) @@ -616,11 +705,9 @@ def pad_l_and_r(x, pad_length): v_new = tf.gather(v_t, gather_indices) v_new = tf.transpose(v_new, [2, 3, 0, 1, 4]) - logits = tf.matmul(q, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) - + output = dot_product_attention( + q, k_new, v_new, attention_bias, dropout_rate=0., name="local_1d", + make_image_summary=False) output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) # Remove the padding if introduced output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) @@ -650,16 +737,13 @@ def local_attention_2d(q, """ with tf.variable_scope( name, default_name="local_self_attention_2d", values=[q, k, v]): + q_shape = q.get_shape().as_list() v_shape = tf.shape(v) - depth_v = tf.shape(v)[4] - batch_size = tf.shape(q)[0] - num_heads = tf.shape(q)[1] - original_length = tf.shape(q)[2] * tf.shape(q)[3] q = pad_to_multiple_2d(q, query_shape) k = pad_to_multiple_2d(k, query_shape) v = pad_to_multiple_2d(v, query_shape) - + padded_q_shape = tf.shape(q) # Setting up k and v values paddings = [[0, 0], [0, 0], [memory_flange[0], memory_flange[1]], [memory_flange[0], memory_flange[1]], [0, 0]] @@ -680,16 +764,16 @@ def local_attention_2d(q, attention_bias = tf.expand_dims( tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) - logits = tf.matmul(q_new, k_new, transpose_b=True) - - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) - - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) - # Remove the padding if introduced - output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) - # [batch, heads, h, w, depth_v] - return tf.reshape(output, v_shape) + output = dot_product_attention(q_new, k_new, v_new, attention_bias, + dropout_rate=0., name="local_2d", + make_image_summary=False) + # putting the representations back in the right place + output = scatter_blocks_2d(output, q_indices, padded_q_shape) + # Remove the padding if introduced + output = tf.slice(output, [0, 0, 0, 0, 0], + [-1, -1, v_shape[2], v_shape[3], -1]) + output.set_shape(q_shape) + return output def pad_to_multiple_2d(x, block_shape): @@ -726,6 +810,19 @@ def gather_blocks_2d(x, indices): return tf.transpose(x_new, [2, 3, 0, 1, 4]) +def scatter_blocks_2d(x, indices, shape): + """scatters blocks from x into shape with indices.""" + x_shape = tf.shape(x) + # [length, batch, heads, dim] + x_t = tf.transpose(tf.reshape(x, [x_shape[0], x_shape[1], -1, x_shape[-1]]), + [2, 0, 1, 3]) + x_t_shape = tf.shape(x_t) + indices = tf.reshape(indices, [-1, 1]) + scattered_x = tf.scatter_nd(indices, x_t, x_t_shape) + scattered_x = tf.transpose(scattered_x, [1, 2, 0, 3]) + return tf.reshape(scattered_x, shape) + + def gather_indices_2d(x, block_shape, block_stride): """Getting gather indices.""" # making an identity matrix kernel @@ -745,6 +842,42 @@ def gather_indices_2d(x, block_shape, block_stride): return tf.cast(indices, tf.int32) +def make_2d_block_raster_mask(query_shape, memory_flange): + """creates a mask for 2d block raster scany. + + The query mask can look to the left, top left, top, and top right, but + not to the right. Inside the query, we have the standard raster scan + masking. + Args: + query_shape: A tuple of ints (query_height, query_width) + memory_flange: A tuple of ints + (memory_flange_height, memory_flange_width) + + Returns: + A tensor of shape query_size, memory_size + """ + # mask inside the query block + query_triangle = tf.matrix_band_part( + tf.ones([np.prod(query_shape), np.prod(query_shape)]), -1, 0) + split_query_masks = tf.split(query_triangle, query_shape[0], axis=1) + # adding mask for left and right + mask_pieces = [ + tf.concat( + [tf.ones([np.prod(query_shape), memory_flange[1]]), + split_query_masks[i], + tf.zeros([np.prod(query_shape), memory_flange[1]]) + ], axis=1) for i in range(query_shape[0])] + # adding mask for top + final_mask = tf.concat( + [tf.ones( + [np.prod(query_shape), + (query_shape[1]+2*memory_flange[1])*memory_flange[0]]), + tf.concat(mask_pieces, axis=1) + ], axis=1) + # 0. is visible location, 1.0 is masked. + return 1. - final_mask + + def masked_local_attention_2d(q, k, v, @@ -769,45 +902,11 @@ def masked_local_attention_2d(q, """ with tf.variable_scope( name, default_name="local_masked_self_attention_2d", values=[q, k, v]): + q_shape = q.get_shape().as_list() v_shape = tf.shape(v) - depth_v = tf.shape(v)[4] - batch_size = tf.shape(q)[0] - num_heads = tf.shape(q)[1] - original_length = tf.shape(q)[2] * tf.shape(q)[3] - def make_mask(query_shape, memory_flange): - """creates a mask. - - The query mask can look to the left, top left, top, and top right, but - not the right. Inside the query, we have the standard raster scan - masking. - Args: - query_shape: A tuple of ints (query_height, query_width) - memory_flange: A tuple of ints - (memory_flange_height, memory_flange_width) - - Returns: - A tensor of shape query_size, memory_size - """ - - query_triangle = tf.matrix_band_part( - tf.ones([np.prod(query_shape), np.prod(query_shape)]), -1, 0) - split_query_masks = tf.split(query_triangle, query_shape[0], axis=1) - mask_pieces = [ - tf.concat( - [tf.ones([np.prod(query_shape), memory_flange[1]]), - split_query_masks[i], - tf.zeros([np.prod(query_shape), memory_flange[1]]) - ], axis=1) for i in range(query_shape[0])] - - final_mask = tf.concat( - [tf.ones( - [np.prod(query_shape), - (query_shape[1]+2*memory_flange[1])*memory_flange[0]]), - tf.concat(mask_pieces, axis=1) - ], axis=1) - # 0. is visible location, 1.0 is masked. - return 1. - final_mask + q = pad_to_multiple_2d(q, query_shape) + padded_q_shape = tf.shape(q) k = pad_to_multiple_2d(k, query_shape) v = pad_to_multiple_2d(v, query_shape) # Setting up k and v values. Padding top, left, and right @@ -824,25 +923,28 @@ def make_mask(query_shape, memory_flange): k_and_v_indices = gather_indices_2d(k, memory_shape, query_shape) k_new = gather_blocks_2d(k, k_and_v_indices) v_new = gather_blocks_2d(v, k_and_v_indices) - logits = tf.matmul(q_new, k_new, transpose_b=True) # Combining the mask for padding and visible region attention_mask_shape = [np.prod(query_shape), (query_shape[0]+memory_flange[0])* (query_shape[1]+2*memory_flange[1])] - attention_mask = tf.cast(make_mask(query_shape, memory_flange), tf.bool) + attention_mask = tf.cast( + make_2d_block_raster_mask(query_shape, memory_flange), tf.bool) # reshaping attention mask to have same dims as logits attention_mask = tf.reshape(attention_mask, [1, 1, 1]+attention_mask_shape) padding_mask = tf.expand_dims( tf.cast(embedding_to_padding(k_new), tf.bool), axis=-2) attention_bias = ( tf.to_float(tf.logical_or(attention_mask, padding_mask)) *-1e9) - attention = tf.nn.softmax(logits + attention_bias) - output = tf.matmul(attention, v_new) - output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) + output = dot_product_attention(q_new, k_new, v_new, attention_bias, + dropout_rate=0., name="masked_local_2d", + make_image_summary=False) + # putting the representations back in the right place + output = scatter_blocks_2d(output, q_indices, padded_q_shape) # Remove the padding if introduced - output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) - # [batch, heads, h, w, depth_v] - return tf.reshape(output, v_shape) + output = tf.slice(output, [0, 0, 0, 0, 0], + [-1, -1, v_shape[2], v_shape[3], -1]) + output.set_shape(q_shape) + return output def compute_qkv(query_antecedent, memory_antecedent, total_key_depth, @@ -962,6 +1064,7 @@ def multihead_attention(query_antecedent, kv_filter_width=1, q_padding="VALID", kv_padding="VALID", + cache=None, name=None): """Multihead scaled-dot-product attention with input/output transformations. @@ -985,11 +1088,28 @@ def multihead_attention(query_antecedent, to be. q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding. kv_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding. - + cache: dict, containing Tensors which are the results of previous + attentions, used for fast decoding. Expects the dict to contrain two + keys; 'k' and 'v', for the initial call the values for these keys should + be empty Tensors of the appropriate shape. + 'k' [batch_size, 0, key_channels] + 'v' [batch_size, 0, value_channels] name: an optional string + Caching: + WARNING: For decoder self-attention, i.e. when memory_antecedent == None, + the caching assumes that the bias contains future masking. + + The caching works by saving all the previous key and value values so that + you are able to send just the last query location to this attention + function. I.e. if the cache dict is provided it assumes the query is of the + shape [batch_size, 1, hiddem_dim] rather than the full memory. + Returns: - A Tensor. + The result of the attention transformation. The output shape is + [batch_size, length_q, hidden_dim] + unless the cache dict is provided in which case only the last memory + position is calculated and the output shape is [batch_size, 1, hidden_dim] Raises: ValueError: if the key depth or value depth are not divisible by the @@ -1009,6 +1129,17 @@ def multihead_attention(query_antecedent, total_value_depth, q_filter_width, kv_filter_width, q_padding, kv_padding) + if cache is not None: + if attention_type != "dot_product": + raise NotImplementedError( + "Caching is not guaranteed to work with attention types other than" + " dot_product.") + if bias is None: + raise ValueError("Bias required for caching. See function docstring " + "for details.") + k = cache["k"] = tf.concat([cache["k"], k], axis=1) + v = cache["v"] = tf.concat([cache["v"], v], axis=1) + q = split_heads(q, num_heads) k = split_heads(k, num_heads) v = split_heads(v, num_heads) diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index d8f6f2b39..7823936fa 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -162,5 +162,88 @@ def testMultiheadSelfAttentionMemoryEfficient(self): self.assertAllClose(dnorm_bias, dnorm_bias_f) self.assertAllClose(dx, dx_f) + def test2dGatherAndScatterInvertibility(self): + """2d gather and scatter invertibility test.""" + batch_size = 2 + num_heads = 2 + height = 4 + width = 6 + depth = 8 + query_shape = (2, 3) + x = np.random.rand(batch_size, num_heads, height, width, depth) + with self.test_session() as session: + x_indices = common_attention.gather_indices_2d( + x, query_shape, query_shape) + gathered_x = common_attention.gather_blocks_2d(x, x_indices) + x_shape = tf.constant([batch_size, num_heads, height, width, depth]) + scattered_x = common_attention.scatter_blocks_2d( + gathered_x, x_indices, x_shape) + session.run(tf.global_variables_initializer()) + res = session.run(scattered_x) + self.assertAllClose(x, res) + + def test2dBlockRasterScanMask(self): + """Testing the 2d block raster scan mask.""" + query_shape = (2, 3) + memory_flange = (2, 1) + with self.test_session() as session: + mask = common_attention.make_2d_block_raster_mask( + query_shape, memory_flange) + res = session.run(mask) + correct_mask = np.array( + [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]]) + self.assertAllClose(correct_mask, res) + + def test2dGather(self): + """Testing 2d index gather and block gather functions.""" + batch_size = 2 + num_heads = 2 + height = 4 + width = 6 + depth = 8 + query_shape = (2, 3) + x = np.random.rand(batch_size, num_heads, height, width, depth) + y = np.reshape(x, (batch_size, num_heads, -1, depth)) + correct_indices = [[0, 1, 2, 6, 7, 8], + [3, 4, 5, 9, 10, 11], + [12, 13, 14, 18, 19, 20], + [15, 16, 17, 21, 22, 23]] + correct_gathered_x = [[[y[0, 0, correct_indices[0]], + y[0, 0, correct_indices[1]], + y[0, 0, correct_indices[2]], + y[0, 0, correct_indices[3]]], + [y[0, 1, correct_indices[0]], + y[0, 1, correct_indices[1]], + y[0, 1, correct_indices[2]], + y[0, 1, correct_indices[3]]]], + [[y[1, 0, correct_indices[0]], + y[1, 0, correct_indices[1]], + y[1, 0, correct_indices[2]], + y[1, 0, correct_indices[3]]], + [y[1, 1, correct_indices[0]], + y[1, 1, correct_indices[1]], + y[1, 1, correct_indices[2]], + y[1, 1, correct_indices[3]]]]] + + with self.test_session() as session: + x_indices = common_attention.gather_indices_2d( + x, query_shape, query_shape) + gathered_x = common_attention.gather_blocks_2d(x, x_indices) + x_indices, gathered_x = session.run([x_indices, gathered_x]) + self.assertAllEqual(correct_indices, x_indices) + self.assertAllClose(correct_gathered_x, gathered_x) + + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py index 2e33c9e94..deae14ddc 100644 --- a/tensor2tensor/layers/common_hparams.py +++ b/tensor2tensor/layers/common_hparams.py @@ -126,13 +126,13 @@ def basic_params1(): # The maximum length of "input" sequence. # Sequences longer than this value will be truncated. 0 or negative values # mean there is no maximum or truncation. - # You can change this behavior by overridding preprocess_examples() method + # You can change this behavior by overridding preprocess_example() method # in your problem class. max_input_seq_length=0, # The maximum length of "target" sequence. # Sequences longer than this value will be truncated. 0 or negative values # mean there is no maximum or truncation. - # You can change this behavior by overridding preprocess_examples() method + # You can change this behavior by overridding preprocess_example() method # in your problem class. max_target_seq_length=0, # This flag allows us to optionally treat a seq-to-seq problem @@ -152,8 +152,7 @@ def basic_params1(): # position in the inputs portion can see the # entire inputs portion. This removes the challenge of # autoregressively predicting the inputs portion. - prepend_mode="none", - ) + prepend_mode="none",) class RangedHParams(object): diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py index bd9ff896d..6554e0d31 100644 --- a/tensor2tensor/layers/common_layers.py +++ b/tensor2tensor/layers/common_layers.py @@ -209,7 +209,7 @@ def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0): return tf.reshape(emb_x, [shape[0], shape[1], shape[2], static_shape[4]]) -def shift_left(x, pad_value=None): +def shift_right(x, pad_value=None): """Shift the second dimension of x right by one.""" if pad_value is None: shifted_targets = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :-1, :, :] @@ -218,7 +218,7 @@ def shift_left(x, pad_value=None): return shifted_targets -def shift_left_3d(x, pad_value=None): +def shift_right_3d(x, pad_value=None): """Shift the second dimension of x right by one.""" if pad_value is None: shifted_targets = tf.pad(x, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] @@ -815,7 +815,7 @@ def decompress_seqcnn(x, # Flatten x and embedded targets. Flat targets are factor* larger on axis=1. flat_x = tf.reshape(x, [-1, 1, 1, hidden_size]) flat_targets = tf.reshape(targets_emb, [-1, factor, 1, hidden_size]) - shifted_targets = shift_left(flat_targets) + shifted_targets = shift_right(flat_targets) # Run a SeqCNN large-batch to produce factor outputs out of every target. flat_x += tf.zeros_like(shifted_targets) # Broadcast on axis=1. flat_outputs = conv_block( diff --git a/tensor2tensor/layers/common_layers_test.py b/tensor2tensor/layers/common_layers_test.py index d11f8ce2c..ee07c48d3 100644 --- a/tensor2tensor/layers/common_layers_test.py +++ b/tensor2tensor/layers/common_layers_test.py @@ -281,7 +281,7 @@ def testShiftLeft(self): expected = np.zeros((5, 7, 1, 11)) expected[:, 1, :] = np.ones_like(expected[:, 1, :]) with self.test_session() as session: - a = common_layers.shift_left(tf.constant(x1, dtype=tf.float32)) + a = common_layers.shift_right(tf.constant(x1, dtype=tf.float32)) actual = session.run(a) self.assertAllEqual(actual, expected) diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index 8502e0a8b..8d1206ee8 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -18,11 +18,15 @@ From [The Reversible Residual Network: Backpropagation Without Storing Activations](https://arxiv.org/abs/1707.04585). + +Also contains the @recompute_grad decorator, which recomputes the forward +function on the backwards pass. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import re # Dependency imports @@ -286,8 +290,8 @@ def custom_grad_fn(inputs, variables, ys, grad_ys): # idxs. f_var_grads.reverse() g_var_grads.reverse() - for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(zip( - g_vars_idxs, g_var_grads)): + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): for i, grad in zip(idxs, grads): variable_grads[i] = grad @@ -316,3 +320,42 @@ def forward(x1, x2, *side_inputs): gate_outputs=is_training) return forward(x1, x2, *(f_side_input + g_side_input)) + + +def recompute_grad(fn): + """Decorator that recomputes the function on the backwards pass. + + Args: + fn: a function that takes Tensors (all as positional arguments) and returns + a tuple of Tensors. + + Returns: + A wrapped fn that is identical to fn when called, but its activations will + be discarded and recomputed on the backwards pass (i.e. on a call to + tf.gradients). + """ + + @functools.wraps(fn) + def wrapped(*args): + return _recompute_grad(fn, args) + + return wrapped + + +def _recompute_grad(fn, args): + """See recompute_grad.""" + + def grad_fn(inputs, variables, outputs, output_grads): + del outputs + # recompute outputs + outputs = list(fn(*inputs)) + grads = tf.gradients(outputs, inputs + variables, output_grads) + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + @common_layers.fn_with_custom_grad(grad_fn) + def fn_with_recompute(*args): + return fn(*args) + + return fn_with_recompute(*args) diff --git a/tensor2tensor/layers/rev_block_test.py b/tensor2tensor/layers/rev_block_test.py index 5aecc8ea3..3e5f7c932 100644 --- a/tensor2tensor/layers/rev_block_test.py +++ b/tensor2tensor/layers/rev_block_test.py @@ -137,5 +137,31 @@ def f(x): self._testRevBlock(x=x, f=f) +class RecomputeTest(tf.test.TestCase): + + def testRecompute(self): + + @rev_block.recompute_grad + def fn_recompute(x, y): + return x + y, x**y + + def fn(x, y): + return x + y, x**y + + x = tf.ones((3, 3)) + y = tf.ones((3, 3)) + out1 = tf.reduce_sum(fn_recompute(x, y)) + out2 = tf.reduce_sum(fn(x, y)) + + grad1 = tf.gradients(out1, [x, y]) + grad2 = tf.gradients(out2, [x, y]) + + with self.test_session() as sess: + outs = sess.run([out1, out2, grad1, grad2]) + self.assertAllClose(outs[0], outs[1]) + for g1, g2 in zip(outs[2], outs[3]): + self.assertAllClose(g1, g2) + + if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/models/__init__.py b/tensor2tensor/models/__init__.py index acebef809..f5fafe706 100644 --- a/tensor2tensor/models/__init__.py +++ b/tensor2tensor/models/__init__.py @@ -23,6 +23,7 @@ # pylint: disable=unused-import from tensor2tensor.layers import modalities +from tensor2tensor.models import aligned from tensor2tensor.models import attention_lm from tensor2tensor.models import attention_lm_moe from tensor2tensor.models import bluenet diff --git a/tensor2tensor/models/aligned.py b/tensor2tensor/models/aligned.py new file mode 100644 index 000000000..90100c842 --- /dev/null +++ b/tensor2tensor/models/aligned.py @@ -0,0 +1,446 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Single stack of transformations with no masking. + +Produces output aligned with inputs. + +Configurable using hyperparameters to use some combination of convolutions, +attention, mixtures of experts, etc. + +A good problem for this model is languagemodel_wiki_scramble1k50 . +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.layers import common_attention +from tensor2tensor.layers import common_hparams +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import diet +from tensor2tensor.utils import expert_utils +from tensor2tensor.utils import registry +from tensor2tensor.utils import t2t_model + +import tensorflow as tf + + +ModeKeys = tf.estimator.ModeKeys # pylint: disable=invalid-name + + +def _should_preprocess(layer_type): + return layer_type not in [ + "timing", "pos_emb", "att_memory_efficient"] + + +def _should_postprocess(layer_type): + return layer_type not in ["timing", "pos_emb"] + + +@registry.register_model +class Aligned(t2t_model.T2TModel): + """Attention net. See file docstring.""" + + def model_fn_body_sharded(self, sharded_features): + # Remove dropout if not training + hparams = self._hparams + dp = self._data_parallelism + x = dp(tf.squeeze, sharded_features["inputs"], 2) + def preprocess(x): + return dp(common_layers.layer_preprocess, x, hparams) + def postprocess(x, y): + return dp(common_layers.layer_postprocess, x, y, hparams) + x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) + extra_loss = 0.0 + ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")] + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + if hparams.diet_experts: + hsize, = moe_hidden_sizes + + def _diet_expert(x): + return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) + + expert_fn = _diet_expert + else: + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + + batch_coordinate = dp(get_batch_coordinate, x) + + layers = hparams.layers.strip(",").split(",") + for layer_num, layer_type in enumerate(layers): + with tf.variable_scope("%s_%d" % (layer_type, layer_num)): + if _should_preprocess(layer_type): + x = preprocess(x) + if layer_type == "timing": + y = dp(common_attention.add_timing_signal_nd, x) + elif layer_type == "pos_emb": + y = dp(common_attention.add_positional_embedding_nd, + x, hparams.max_length, name="pos_emb") + elif layer_type == "att": + y = dp( + common_attention.multihead_attention, + x, + None, + None, # bias + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + elif layer_type == "att_memory_efficient": + assert hparams.layer_preprocess_sequence == "n" + zero_bias = tf.zeros([1, 1, 1, 1]) + y = dp( + common_attention.multihead_self_attention_memory_efficient, + x, + zero_bias, + hparams.num_heads) + elif layer_type == "att_local": + y = dp( + common_attention.multihead_attention, + x, + None, + None, # bias + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + attention_type="local_unmasked", + block_length=hparams.local_attention_window, + block_width=hparams.local_attention_window) + elif layer_type == "att_pseudolocal": + # This is an inefficient implementation of local attention, for the + # purpose of testing model quality. + def _pseudolocal_bias(x): + return common_attention.attention_bias_local( + tf.shape(x)[1], + hparams.local_attention_window, + hparams.local_attention_window) + pseudolocal_bias = dp(_pseudolocal_bias, x) + y = dp( + common_attention.multihead_attention, + x, + None, + pseudolocal_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout) + elif layer_type == "att_local_expert": + y, loss = dp( + common_attention.local_expert_attention, + x, + k=hparams.attention_moe_k, + loss_coef=hparams.attention_load_balance, + attention_num_experts=hparams.attention_num_experts, + train=hparams.mode == ModeKeys.TRAIN, + batch_coordinate=batch_coordinate, + mask_right=False, + split_batch=bool(hparams.attention_split_batch), + attention_kq_size=hparams.attention_kq_size, + attention_v_size=hparams.attention_v_size) + # TODO(avaswani, epot, noam): Do we need to divide by num shards ? + extra_loss += tf.add_n(loss) / dp.n + elif layer_type == "moe": + y, loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + x, + hparams.mode == ModeKeys.TRAIN, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) + extra_loss += loss + elif layer_type == "ffn": + y = dp( + expert_utils.ffn_expert_fn( + hparams.hidden_size, + ffn_hidden_sizes, + hparams.hidden_size), + dp(expert_utils.flatten_all_but_last, x)) + y = dp(expert_utils.reshape_like, y, x) + elif layer_type == "conv": + y = dp( + common_layers.conv1d, + x, + hparams.hidden_size, + hparams.kernel_height, + activation=tf.nn.relu, + padding="SAME", + ) + else: + assert False, "unknown sublayer %s" % layer_type + if _should_postprocess(layer_type): + x = postprocess(x, y) + else: + x = y + x = preprocess(x) + + decoder_output = dp(tf.expand_dims, x, 2) + return decoder_output, extra_loss + + +def get_batch_coordinate(x): + """Return a flat int32 tensor of shape [1, batch_size*length, 1].""" + # Compute the batch coordinate before flattening all batches + batch_coordinate = tf.expand_dims( + common_attention.coordinate_tensor(tf.shape(x)[:-1], axis=0), axis=-1) + return batch_coordinate + + +@registry.register_hparams +def aligned_base(): + """Set of hyperparameters. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps (10min): log(ppl)_eval = 2.60 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + + Returns: + a hparams object + """ + hparams = common_hparams.basic_params1() + hparams.hidden_size = 512 + hparams.batch_size = 5000 + hparams.max_length = 1024 + hparams.min_length_bucket = 1024 + hparams.dropout = 0.0 + hparams.layer_prepostprocess_dropout = 0.0 + hparams.label_smoothing = 0.0 + hparams.clip_grad_norm = 0. # i.e. no gradient clipping + hparams.optimizer_adam_epsilon = 1e-9 + hparams.learning_rate_decay_scheme = "noam" + hparams.learning_rate = 0.1 + hparams.learning_rate_warmup_steps = 2000 + hparams.initializer_gain = 1.0 + hparams.initializer = "uniform_unit_scaling" + hparams.weight_decay = 0.0 + hparams.optimizer_adam_beta1 = 0.9 + hparams.optimizer_adam_beta2 = 0.98 + hparams.shared_embedding_and_softmax_weights = int(True) + hparams.add_hparam("ffn_hidden_sizes", "2048") # Add new ones like this. + hparams.moe_num_experts = 32 + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.add_hparam("layers", "timing," + "conv,att,ffn," * 2) + + # attention-related flags + hparams.add_hparam("num_heads", 8) + hparams.add_hparam("attention_key_channels", 0) + hparams.add_hparam("attention_value_channels", 0) + # All hyperparameters ending in "dropout" are automatically set to 0.0 + # when not in training mode. + hparams.add_hparam("attention_dropout", 0.0) + hparams.add_hparam("pos", "timing") # timing, none + # moe params. local attention moe. + hparams.add_hparam("attention_local", int(False)) + hparams.add_hparam("attention_moe_k", 2) + hparams.add_hparam("attention_num_experts", 16) + hparams.add_hparam("attention_split_batch", int(False)) + # Key, query and value dimensions for the attention + hparams.add_hparam("attention_kq_size", 128) + hparams.add_hparam("attention_v_size", 256) + # Loss coef for load balancing + hparams.add_hparam("attention_load_balance", 2e-2) + hparams.add_hparam("diet_experts", int(False)) + hparams.add_hparam("memory_efficient_ffn", int(False)) + hparams.add_hparam("local_attention_window", 128) + # if True, we learn a non-autoregressive model from "inputs" to "targets". + # if False, we learn an autoregressive model to generate "targets" + return hparams + + +@registry.register_hparams +def aligned_memory_efficient(): + """Use multihead_self_attention_memory_efficient. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.59 + 8.7 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.02 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_memory_efficient,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_local_expert(): + """Use local_expert_attention. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.72 + 10.2 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.27 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_local_expert,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_local(): + """Use local attention code. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 12.8 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.08 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_local,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_local_1k(): + """Use local attention code, attend to full sequence. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 7.5 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + + Returns: + a hparams object + """ + hparams = aligned_local() + hparams.local_attention_window = 1024 + return hparams + + +@registry.register_hparams +def aligned_pseudolocal(): + """Use a bias to simulate local attention. attention radius 128. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.57 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.06 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_pseudolocal,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_pseudolocal_256(): + """Use a bias to simulate local attention. attentio radius 256. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.56 + 12.0 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.05 + + Returns: + a hparams object + """ + hparams = aligned_pseudolocal() + hparams.local_attention_window = 256 + return hparams + + +@registry.register_hparams +def aligned_no_timing(): + """No timing signal. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.75 + 12.3 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.39 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "conv,att,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_no_att(): + """No attention at all. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.89 + 20.8 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.70 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "conv,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_pos_emb(): + """positional embedding insead of timing signal. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.67 + 12.1 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.00 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "pos_emb," + "conv,att,ffn," * 2 + return hparams + + +@registry.register_hparams +def aligned_moe(): + """mixture of experts instead of ffn. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.62 + 6.7 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 1.94 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att,moe," * 2 + return hparams + + +@registry.register_hparams +def aligned_8k(): + """version for languagemodel_wiki_scramble8k50. + + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.93 + 1.5 steps/sec on P100 + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.max_length = 8192 + hparams.batch_size = 8192 + return hparams diff --git a/tensor2tensor/models/attention_lm.py b/tensor2tensor/models/attention_lm.py index 3302f45be..696057233 100644 --- a/tensor2tensor/models/attention_lm.py +++ b/tensor2tensor/models/attention_lm.py @@ -79,7 +79,7 @@ def attention_lm_prepare_decoder(targets, hparams): else: decoder_self_attention_bias = ( common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index adbb871b5..42a9fbabf 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -60,6 +60,13 @@ def get_choices(): ] +LAYER_SYMBOLS = { + "h": AttentionType.MULTIHEAD, # multi-Head + "e": AttentionType.LOCAL_EXPERTS, # Experts + "m": AttentionType.MEMORY_EFFICIENT, # Memory +} + + @registry.register_model class AttentionLmMoe(t2t_model.T2TModel): """Attention net. See file docstring.""" @@ -98,8 +105,7 @@ def _diet_expert(x): expert_fn = expert_utils.ffn_expert_fn( hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) - if (hparams.attention_type == AttentionType.LOCAL_EXPERTS - and not hparams.use_inputs): + if not hparams.use_inputs: # As preprocess and postprocess are called with batch of size one (all # batches concatenated), we just make sure that batch_norm is not use ( # should not either way) @@ -128,16 +134,23 @@ def print_shape(x, suffix, debug=False): batch_coordinate = dp_remove_pad(batch_coordinate) x = dp(print_shape, x, "in") - x = dp_remove_pad(x) - x = dp(print_shape, x, "in_flat") assert hparams.batch_size >= hparams.max_length - for layer in xrange(hparams.num_hidden_layers): + num_hidden_layers = ( + len(hparams.attention_layers) or hparams.num_hidden_layers) + for layer in xrange(num_hidden_layers): with tf.variable_scope("layer_%d" % layer): + + # Use the layer type defined in attention_layers + if hparams.attention_layers: + attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]] + else: + attention_type = hparams.attention_type + with tf.variable_scope( - "attention_{}".format(hparams.attention_type)): - if hparams.attention_type == AttentionType.MULTIHEAD: + "attention_{}".format(attention_type)): + if attention_type == AttentionType.MULTIHEAD: y = dp( common_attention.multihead_attention, preprocess(x), @@ -151,7 +164,7 @@ def print_shape(x, suffix, debug=False): attention_type=("local_mask_right" if hparams.attention_local else "dot_product"), name="decoder_self_attention") - elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT: + elif attention_type == AttentionType.MEMORY_EFFICIENT: assert hparams.layer_preprocess_sequence == "n" y = dp( common_attention.multihead_self_attention_memory_efficient, @@ -159,10 +172,12 @@ def print_shape(x, suffix, debug=False): decoder_self_attention_bias, hparams.num_heads, name="decoder_self_attention") - elif hparams.attention_type == AttentionType.LOCAL_EXPERTS: + elif attention_type == AttentionType.LOCAL_EXPERTS: + x_in = preprocess(x) + x_in = dp_remove_pad(x_in) y, loss = dp( common_attention.local_expert_attention, - preprocess(x), + x_in, k=hparams.attention_moe_k, loss_coef=hparams.attention_load_balance, attention_num_experts=hparams.attention_num_experts, @@ -172,6 +187,7 @@ def print_shape(x, suffix, debug=False): split_batch=bool(hparams.attention_split_batch), attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) + y = dp_restore_pad(y) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss) / dp.n else: @@ -198,15 +214,8 @@ def print_shape(x, suffix, debug=False): x, hparams.filter_size) else: - x_in = preprocess(x) additional_conv_params = dict() if hparams.use_sepconv: - # Restore padding so sequences don't attend to each others - # restore_pad will apply a reshape like x_ref, to restore the - # original shape. Here this works because the last dimension is - # constant between the output of attention and the original input - # but it shouldn't necessarily be the case. - x_in = dp_restore_pad(x_in) additional_conv_params = dict( padding="LEFT", # Parameters copied from the transformer model @@ -215,19 +224,15 @@ def print_shape(x, suffix, debug=False): ) y = dp( common_layers.conv_hidden_relu, - x_in, + preprocess(x), hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, **additional_conv_params ) - if hparams.use_sepconv: - y = dp_remove_pad(y) x = postprocess(x, y) x = preprocess(x) - x = dp_restore_pad(x) - decoder_output = dp(tf.expand_dims, x, 2) return decoder_output, extra_loss @@ -257,7 +262,7 @@ def attention_lm_moe_prepare_decoder(targets, hparams): common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) # TODO(epot): The padding remover should take into account that the input is # shifted. - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias, pad_remover) @@ -350,6 +355,10 @@ def attention_lm_moe_base(): hparams.add_hparam("pos", "timing") # timing, none hparams.add_hparam("moe_layers", "2") # comma separated list of layer numbers # moe params. local attention moe. + # If attention_layers is set, the num_hidden_layers parameter will be ignored + # and each caracter of the string will correspond to one attention + # layer type + hparams.add_hparam("attention_layers", "") hparams.add_hparam("attention_type", AttentionType.MULTIHEAD) hparams.add_hparam("attention_local", int(False)) hparams.add_hparam("attention_moe_k", 2) @@ -370,14 +379,24 @@ def attention_lm_moe_base(): @registry.register_hparams -def attention_lm_moe_base_ae(): - """Base model with attention expert.""" +def attention_lm_moe_base_long_seq(): + """Hyper parameters specifics for long sequence generation.""" hparams = attention_lm_moe_base() - hparams.attention_type = AttentionType.LOCAL_EXPERTS - hparams.use_sepconv = int(True) + hparams.max_length = 0 # max_length == batch_size hparams.eval_drop_long_sequences = int(True) hparams.min_length_bucket = 256 # Avoid cyclic problems for big batches + hparams.use_sepconv = int(True) + + return hparams + + +@registry.register_hparams +def attention_lm_moe_base_ae(): + """Base model with attention expert.""" + hparams = attention_lm_moe_base_long_seq() + hparams.attention_type = AttentionType.LOCAL_EXPERTS + hparams.learning_rate = 0.05 hparams.learning_rate_warmup_steps = 10000 # According to noam, ("n", "da") seems better for harder-to-learn models @@ -389,12 +408,20 @@ def attention_lm_moe_base_ae(): @registry.register_hparams def attention_lm_moe_base_local(): """Base model with attention expert.""" - hparams = attention_lm_moe_base() + hparams = attention_lm_moe_base_long_seq() hparams.attention_local = int(True) - hparams.use_sepconv = int(True) - hparams.max_length = 0 # max_length == batch_size - hparams.eval_drop_long_sequences = int(True) - hparams.min_length_bucket = 256 # Avoid cyclic problems for big batches + return hparams + + +@registry.register_hparams +def attention_lm_moe_base_hybrid(): + """Base model with attention expert.""" + hparams = attention_lm_moe_base_long_seq() + hparams.attention_layers = "hehe" # Alternate local/expert + hparams.attention_local = int(True) + + # hparams.layer_preprocess_sequence = "n" + # hparams.layer_postprocess_sequence = "da" return hparams diff --git a/tensor2tensor/models/bluenet_test.py b/tensor2tensor/models/bluenet_test.py index d559fd953..daf87529e 100644 --- a/tensor2tensor/models/bluenet_test.py +++ b/tensor2tensor/models/bluenet_test.py @@ -36,8 +36,7 @@ def testBlueNet(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) hparams = bluenet.bluenet_tiny() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: tf.train.get_or_create_global_step() features = { diff --git a/tensor2tensor/models/bytenet.py b/tensor2tensor/models/bytenet.py index e4537ef3f..5af0c4435 100644 --- a/tensor2tensor/models/bytenet.py +++ b/tensor2tensor/models/bytenet.py @@ -66,7 +66,7 @@ def bytenet_internal(inputs, targets, hparams): final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat, "SAME", "encoder", hparams) - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) kernel = (hparams.kernel_height, hparams.kernel_width) decoder_start = common_layers.conv_block( tf.concat([final_encoder, shifted_targets], axis=3), diff --git a/tensor2tensor/models/bytenet_test.py b/tensor2tensor/models/bytenet_test.py index 56f421153..f96d3b999 100644 --- a/tensor2tensor/models/bytenet_test.py +++ b/tensor2tensor/models/bytenet_test.py @@ -36,8 +36,7 @@ def testByteNet(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = bytenet.bytenet_base() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), diff --git a/tensor2tensor/models/lstm.py b/tensor2tensor/models/lstm.py index d1c3101b4..f336bd6b4 100644 --- a/tensor2tensor/models/lstm.py +++ b/tensor2tensor/models/lstm.py @@ -221,7 +221,7 @@ def lstm_seq2seq_internal(inputs, targets, hparams, train): _, final_encoder_state = lstm( tf.reverse(inputs, axis=[1]), hparams, train, "encoder") # LSTM decoder. - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) decoder_outputs, _ = lstm( common_layers.flatten4d3d(shifted_targets), hparams, @@ -240,7 +240,7 @@ def lstm_seq2seq_internal_attention(inputs, targets, hparams, train): encoder_outputs, final_encoder_state = lstm( tf.reverse(inputs, axis=[1]), hparams, train, "encoder") # LSTM decoder with attention - shifted_targets = common_layers.shift_left(targets) + shifted_targets = common_layers.shift_right(targets) decoder_outputs, _ = lstm_attention_decoder( common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder", final_encoder_state, encoder_outputs) @@ -266,13 +266,20 @@ def model_fn_body(self, features): @registry.register_hparams -def lstm_attention(): - """hparams for LSTM with attention.""" +def lstm_seq2seq(): + """hparams for LSTM.""" hparams = common_hparams.basic_params1() hparams.batch_size = 1024 hparams.hidden_size = 128 hparams.num_hidden_layers = 2 hparams.initializer = "uniform_unit_scaling" + return hparams + + +@registry.register_hparams +def lstm_attention(): + """hparams for LSTM with attention.""" + hparams = lstm_seq2seq() # Attention hparams.add_hparam("attn_vec_size", hparams.hidden_size) diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py index c1190d016..0d4bc6d80 100644 --- a/tensor2tensor/models/lstm_test.py +++ b/tensor2tensor/models/lstm_test.py @@ -37,8 +37,7 @@ def testLSTMSeq2Seq(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = common_hparams.basic_params1() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), @@ -58,8 +57,7 @@ def testLSTMSeq2SeqAttention(self): y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 6, 1, 1)) hparams = lstm.lstm_attention() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) x = tf.constant(x, dtype=tf.int32) x._shape = tf.TensorShape([None, None, 1, 1]) diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 5df8fcd3c..a4c82d942 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -99,7 +99,7 @@ def prepare_decoder(targets, target_space_emb): common_attention.attention_bias_lower_triangle(tf.shape(targets)[1])) target_space_emb = tf.reshape(target_space_emb, [1, 1, -1]) target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1]) - decoder_input = common_layers.shift_left_3d( + decoder_input = common_layers.shift_right_3d( targets, pad_value=target_space_emb) decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) diff --git a/tensor2tensor/models/neural_gpu_test.py b/tensor2tensor/models/neural_gpu_test.py index 164623699..75149ddd5 100644 --- a/tensor2tensor/models/neural_gpu_test.py +++ b/tensor2tensor/models/neural_gpu_test.py @@ -39,7 +39,7 @@ def testNeuralGPU(self): target_length = input_length input_vocab_size = 9 target_vocab_size = 11 - p_hparams = problem_hparams.test_problem_hparams(hparams, input_vocab_size, + p_hparams = problem_hparams.test_problem_hparams(input_vocab_size, target_vocab_size) inputs = -1 + np.random.random_integers( input_vocab_size, size=(batch_size, input_length, 1, 1)) diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 6b07dc640..5377fd97e 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -198,7 +198,7 @@ def norm_fn(x, name): similarity_loss = 0.0 # Use attention from each target to look at input and retrieve. - targets_shifted = common_layers.shift_left( + targets_shifted = common_layers.shift_right( targets_flat, pad_value=target_space_emb) if hparams.attention_type == "none": targets_with_attention = tf.zeros_like(targets_shifted) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index a2e76dd13..7d4ce27be 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -41,34 +41,255 @@ class Transformer(t2t_model.T2TModel): """Attention net. See file docstring.""" + def encode(self, inputs, target_space, hparams): + """Encode transformer inputs. + + Args: + inputs: Transformer inputs [batch_size, input_length, hidden_dim] + target_space: scalar, target space ID. + hparams: hyperparmeters for model. + + Returns: + Tuple of: + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encodre-decoder attention. [batch_size, input_length] + """ + inputs = common_layers.flatten4d3d(inputs) + + encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( + transformer_prepare_encoder(inputs, target_space, hparams)) + + encoder_input = tf.nn.dropout( + encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) + + encoder_output = transformer_encoder( + encoder_input, + self_attention_bias, + hparams) + + return encoder_output, encoder_decoder_attention_bias + + def decode( + self, + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams, + cache=None): + """Decode Transformer outputs from encoder representation. + + Args: + decoder_input: inputs to bottom of the model. + [batch_size, decoder_length, hidden_dim] + encoder_output: Encoder representation. + [batch_size, input_length, hidden_dim] + encoder_decoder_attention_bias: Bias and mask weights for + encoder-decoder attention. [batch_size, input_length] + decoder_self_attention_bias: Bias and mask weights for decoder + self-attention. [batch_size, decoder_length] + hparams: hyperparmeters for model. + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. + + Returns: + Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + """ + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer_decoder( + decoder_input, + encoder_output, + decoder_self_attention_bias, + encoder_decoder_attention_bias, + hparams, + cache=cache) + + # Expand since t2t expects 4d tensors. + return tf.expand_dims(decoder_output, axis=2) + def model_fn_body(self, features): + """Transformet main model_fn. + + Args: + features: Map of features to the model. Should contain the following: + "inputs": Transformer inputs [batch_size, input_length, hidden_dim] + "tragets": Target decoder outputs. + [batch_size, decoder_length, hidden_dim] + "target_space_id" + + Returns: + Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + """ hparams = self._hparams - targets = features["targets"] + inputs = features["inputs"] + target_space = features["target_space_id"] + encoder_output, encoder_decoder_attention_bias = self.encode( + inputs, target_space, hparams) - inputs = common_layers.flatten4d3d(inputs) + targets = features["targets"] targets = common_layers.flatten4d3d(targets) - (encoder_input, encoder_self_attention_bias, - encoder_decoder_attention_bias) = transformer_prepare_encoder( - inputs, target_space, hparams) - (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( + decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams) - encoder_input = tf.nn.dropout(encoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = tf.nn.dropout(decoder_input, - 1.0 - hparams.layer_prepostprocess_dropout) - encoder_output = transformer_encoder(encoder_input, - encoder_self_attention_bias, hparams) - - decoder_output = transformer_decoder( - decoder_input, encoder_output, decoder_self_attention_bias, - encoder_decoder_attention_bias, hparams) - decoder_output = tf.expand_dims(decoder_output, 2) + return self.decode( + decoder_input, + encoder_output, + encoder_decoder_attention_bias, + decoder_self_attention_bias, + hparams) + + def _greedy_infer( + self, features, decode_length, last_position_only=True): + """Fast version of greedy decoding. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + last_position_only: MUST be true for fast decoding! + + Returns: + samples: [batch_size, input_length + decode_length] + logits: Not returned + losses: Not returned + + Raises: + ValueError: If last_position_only if False + NotImplementedError: If there are multiple data shards. + """ + if not last_position_only: + raise ValueError("Fast decoding only deals with the last positions!") + if self._num_datashards != 1: + raise NotImplementedError("Fast decoding only supports a single shard.") + dp = self._data_parallelism + hparams = self._hparams - return decoder_output + inputs = features["inputs"] + batch_size = tf.shape(inputs)[0] + target_modality = self._problem_hparams.target_modality + if t2t_model.is_class_modality(target_modality): + decode_length = 1 + else: + decode_length = tf.shape(inputs)[1] + decode_length + + # TODO(llion): Clean up this reshaping logic. + inputs = tf.expand_dims(inputs, axis=1) + if len(inputs.shape) < 5: + inputs = tf.expand_dims(inputs, axis=4) + s = tf.shape(inputs) + inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]]) + # _shard_features called to ensure that the variable names match + inputs = self._shard_features({"inputs": inputs})["inputs"] + input_modality = self._problem_hparams.input_modality["inputs"] + with tf.variable_scope(input_modality.name): + inputs = input_modality.bottom_sharded(inputs, dp) + with tf.variable_scope("body"): + encoder_output, encoder_decoder_attention_bias = dp( + self.encode, inputs, features["target_space_id"], hparams) + + if hparams.pos == "timing": + timing_signal = common_attention.get_timing_signal_1d( + decode_length + 1, hparams.hidden_size) + + def preprocess_targets(targets, i): + """Performs preprocessing steps on the targets to prepare for the decoder. + + This includes: + - Embedding the ids. + - Flattening to 3D tensor. + - Optionally adding timing signals. + + Args: + targets: inputs ids to the decoder. [batch_size, 1] + i: scalar, Step number of the decoding loop. + + Returns: + Processed targets [batch_size, 1, hidden_dim] + """ + # _shard_features called to ensure that the variable names match + targets = self._shard_features({"targets": targets})["targets"] + with tf.variable_scope(target_modality.name): + targets = target_modality.targets_bottom_sharded(targets, dp)[0] + targets = common_layers.flatten4d3d(targets) + + # TODO(llion): Explain! Is this even needed? + targets = tf.cond( + tf.equal(i, 0), + lambda: tf.zeros_like(targets), + lambda: targets) + + if hparams.pos == "timing": + targets += timing_signal[:, i:i+1] + return targets + + decoder_self_attention_bias = ( + common_attention.attention_bias_lower_triangle(decode_length)) + if hparams.proximity_bias: + decoder_self_attention_bias += common_attention.attention_bias_proximal( + decode_length) + + def symbols_to_logits_fn(ids, i, cache): + """Go from ids to logits for next symbol.""" + targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) + targets = preprocess_targets(targets, i) + + bias = decoder_self_attention_bias[:, :, i:i+1, :i+1] + + with tf.variable_scope("body"): + body_outputs = dp( + self.decode, + targets, + encoder_output[0], + encoder_decoder_attention_bias[0], + bias, + hparams, + cache) + + with tf.variable_scope(target_modality.name): + logits = target_modality.top_sharded(body_outputs, None, dp)[0] + + return tf.squeeze(logits, axis=[1, 2, 3]) + + def inner_loop(i, next_id, decoded_ids, cache): + logits = symbols_to_logits_fn(next_id, i, cache) + next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1) + decoded_ids = tf.concat([decoded_ids, next_id], axis=1) + return i+1, next_id, decoded_ids, cache + + key_channels = hparams.attention_key_channels or hparams.hidden_size + value_channels = hparams.attention_value_channels or hparams.hidden_size + num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers + + cache = { + "layer_%d" % layer: { + "k": tf.zeros([batch_size, 0, key_channels]), + "v": tf.zeros([batch_size, 0, value_channels]), + } for layer in range(num_layers) + } + decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64) + next_id = tf.zeros([batch_size, 1], dtype=tf.int64) + _, _, decoded_ids, _ = tf.while_loop( + # TODO(llion): Early stopping. + lambda i, *_: tf.less(i, decode_length), + inner_loop, + [tf.constant(0), next_id, decoded_ids, cache], + shape_invariants=[ + tf.TensorShape([]), + tf.TensorShape([None, None]), + tf.TensorShape([None, None]), + {"layer_%d" % layer: { + "k": tf.TensorShape([None, None, key_channels]), + "v": tf.TensorShape([None, None, value_channels]), + } for layer in range(num_layers)} + ]) + + return decoded_ids, None, None @registry.register_model @@ -89,6 +310,7 @@ def model_fn_body(self, features): 1.0 - hparams.layer_prepostprocess_dropout) encoder_output = transformer_encoder(encoder_input, encoder_self_attention_bias, hparams) + encoder_output = tf.expand_dims(encoder_output, 2) return encoder_output @@ -167,7 +389,7 @@ def transformer_prepare_decoder(targets, hparams): if hparams.proximity_bias: decoder_self_attention_bias += common_attention.attention_bias_proximal( tf.shape(targets)[1]) - decoder_input = common_layers.shift_left_3d(targets) + decoder_input = common_layers.shift_right_3d(targets) if hparams.pos == "timing": decoder_input = common_attention.add_timing_signal_1d(decoder_input) return (decoder_input, decoder_self_attention_bias) @@ -211,10 +433,10 @@ def transformer_encoder(encoder_input, y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams, pad_remover) x = common_layers.layer_postprocess(x, y, hparams) - # if normalization is done in layer_preprocess, then it shuold also be done - # on the output, since the output can grow very large, being the sum of - # a whole stack of unnormalized layer outputs. - return common_layers.layer_preprocess(x, hparams) + # if normalization is done in layer_preprocess, then it shuold also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + return common_layers.layer_preprocess(x, hparams) def transformer_decoder(decoder_input, @@ -222,6 +444,7 @@ def transformer_decoder(decoder_input, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams, + cache=None, name="decoder"): """A stack of transformer layers. @@ -233,6 +456,8 @@ def transformer_decoder(decoder_input, encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention (see common_attention.attention_bias()) hparams: hyperparameters for model + cache: dict, containing tensors which are the results of previous + attentions, used for fast decoding. name: a string Returns: @@ -242,20 +467,28 @@ def transformer_decoder(decoder_input, with tf.variable_scope(name): for layer in xrange(hparams.num_decoder_layers or hparams.num_hidden_layers): - with tf.variable_scope("layer_%d" % layer): + layer_name = "layer_%d" % layer + layer_cache = cache[layer_name] if cache is not None else None + with tf.variable_scope(layer_name): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( - common_layers.layer_preprocess( - x, hparams), None, decoder_self_attention_bias, + common_layers.layer_preprocess(x, hparams), + None, + decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: with tf.variable_scope("encdec_attention"): + # TODO(llion): Add caching. y = common_attention.multihead_attention( - common_layers.layer_preprocess( - x, hparams), encoder_output, encoder_decoder_attention_bias, + common_layers.layer_preprocess(x, hparams), + encoder_output, + encoder_decoder_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, @@ -265,10 +498,10 @@ def transformer_decoder(decoder_input, y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams) x = common_layers.layer_postprocess(x, y, hparams) - # if normalization is done in layer_preprocess, then it shuold also be done - # on the output, since the output can grow very large, being the sum of - # a whole stack of unnormalized layer outputs. - return common_layers.layer_preprocess(x, hparams) + # if normalization is done in layer_preprocess, then it shuold also be done + # on the output, since the output can grow very large, being the sum of + # a whole stack of unnormalized layer outputs. + return common_layers.layer_preprocess(x, hparams) def transformer_ffn_layer(x, hparams, pad_remover=None): @@ -654,6 +887,14 @@ def transformer_parameter_attention_b(): return hparams +@registry.register_hparams +def transformer_prepend(): + hparams = transformer_base() + hparams.prepend_mode = "prepend_inputs_masked_attention" + hparams.max_length = 0 + return hparams + + @registry.register_ranged_hparams("transformer_base") def transformer_base_range(rhp): """Small range of hyperparameters.""" diff --git a/tensor2tensor/models/transformer_revnet_test.py b/tensor2tensor/models/transformer_revnet_test.py index f9bc8cfb2..f61b88b5b 100644 --- a/tensor2tensor/models/transformer_revnet_test.py +++ b/tensor2tensor/models/transformer_revnet_test.py @@ -46,8 +46,7 @@ def testTransformer(self): target_length = 7 vocab_size = 9 hparams = transformer_revnet_test() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) hparams.problems = [p_hparams] inputs = -1 + np.random.random_integers( vocab_size, size=(batch_size, input_length, 1, 1)) diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 9e450a670..22848b249 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -32,16 +32,21 @@ BATCH_SIZE = 3 INPUT_LENGTH = 5 TARGET_LENGTH = 7 -VOCAB_SIZE = 9 +VOCAB_SIZE = 10 class TransformerTest(tf.test.TestCase): - def getModel(self): + def getModel(self, mode=tf.estimator.ModeKeys.TRAIN): hparams = transformer.transformer_small() - p_hparams = problem_hparams.test_problem_hparams( - hparams, VOCAB_SIZE, VOCAB_SIZE) + hparams.hidden_size = 8 + hparams.filter_size = 32 + hparams.num_heads = 1 + hparams.layer_prepostprocess_dropout = 0.0 + + p_hparams = problem_hparams.test_problem_hparams(VOCAB_SIZE, VOCAB_SIZE) hparams.problems = [p_hparams] + inputs = -1 + np.random.random_integers( VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1)) targets = -1 + np.random.random_integers( @@ -64,6 +69,39 @@ def testTransformer(self): res = session.run(logits) self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE)) + def testGreedyVsFast(self): + model, features = self.getModel() + + decode_length = 2 + + out_logits, _ = model.model_fn(features) + out_logits = tf.squeeze(out_logits[0], axis=[2, 3]) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]), + labels=tf.reshape(features["targets"], [-1])) + loss = tf.reduce_mean(loss) + apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss) + + with self.test_session(): + tf.global_variables_initializer().run() + for _ in range(100): + apply_grad.run() + + model, _ = self.getModel(tf.estimator.ModeKeys.PREDICT) + + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + greedy_result, _, _ = model._slow_greedy_infer( + features, decode_length, last_position_only=True) + greedy_result = tf.squeeze(greedy_result, axis=[2, 3]) + + fast_result, _, _ = model._greedy_infer(features, decode_length) + + with self.test_session(): + greedy_res = greedy_result.eval() + fast_res = fast_result.eval() + + self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length)) + self.assertAllClose(greedy_res, fast_res) if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index e3279495a..86950d6b7 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -187,7 +187,7 @@ def encode(x, x_space, hparams, name): def decode(cond_vec, cond_add, gold, c, ed, hparams): """Transformer decoder.""" drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = common_layers.shift_left(drop_gold, pad_value=cond_vec) + decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec) if cond_add is not None: decoder_input += cond_add decoder_input = tf.squeeze(decoder_input, axis=2) diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py index eb4c6db20..9114fb781 100644 --- a/tensor2tensor/models/xception_test.py +++ b/tensor2tensor/models/xception_test.py @@ -36,8 +36,7 @@ def testXception(self): x = np.random.random_integers(1, high=vocab_size - 1, size=(3, 5, 1, 1)) y = np.random.random_integers(1, high=vocab_size - 1, size=(3, 1, 1, 1)) hparams = xception.xception_tiny() - p_hparams = problem_hparams.test_problem_hparams(hparams, vocab_size, - vocab_size) + p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size) with self.test_session() as session: features = { "inputs": tf.constant(x, dtype=tf.int32), diff --git a/tensor2tensor/utils/avg_checkpoints.py b/tensor2tensor/utils/avg_checkpoints.py index 77acd4353..4d1c56eda 100644 --- a/tensor2tensor/utils/avg_checkpoints.py +++ b/tensor2tensor/utils/avg_checkpoints.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import os + # Dependency imports import numpy as np @@ -30,6 +32,9 @@ flags.DEFINE_string("checkpoints", "", "Comma-separated list of checkpoints to average.") +flags.DEFINE_integer("num_last_checkpoints", 0, + "Averages the last N saved checkpoints." + " If the checkpoints flag is set, this is ignored.") flags.DEFINE_string("prefix", "", "Prefix (e.g., directory) to append to each checkpoint.") flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", @@ -42,17 +47,32 @@ def checkpoint_exists(path): def main(_): - # Get the checkpoints list from flags and run some basic checks. - checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] - checkpoints = [c for c in checkpoints if c] - if not checkpoints: - raise ValueError("No checkpoints provided for averaging.") - if FLAGS.prefix: - checkpoints = [FLAGS.prefix + c for c in checkpoints] + if FLAGS.checkpoints: + # Get the checkpoints list from flags and run some basic checks. + checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] + checkpoints = [c for c in checkpoints if c] + if not checkpoints: + raise ValueError("No checkpoints provided for averaging.") + if FLAGS.prefix: + checkpoints = [FLAGS.prefix + c for c in checkpoints] + else: + assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model" + assert FLAGS.prefix, ("Prefix must be provided when averaging last" + " N checkpoints") + checkpoint_state = tf.train.get_checkpoint_state( + os.path.dirname(FLAGS.prefix)) + # Checkpoints are ordered from oldest to newest. + checkpoints = checkpoint_state.all_model_checkpoint_paths[ + -FLAGS.num_last_checkpoints:] + checkpoints = [c for c in checkpoints if checkpoint_exists(c)] if not checkpoints: - raise ValueError( - "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) + if FLAGS.checkpoints: + raise ValueError( + "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) + else: + raise ValueError("Could not find checkpoints at %s" % + os.path.dirname(FLAGS.prefix)) # Read variables from all checkpoints and average them. tf.logging.info("Reading variables and averaging checkpoints:") diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index be6c28559..c5e8eb85e 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -107,7 +107,6 @@ def beam_search(symbols_to_logits_fn, eos_id=EOS_ID): """Beam search with length penalties. - Uses an interface specific to the sequence cnn models; Requires a function that can take the currently decoded sybmols and return the logits for the next symbol. The implementation is inspired by https://arxiv.org/abs/1609.08144. diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index 834e631ac..e88d208ac 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -29,8 +29,6 @@ from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import zip # pylint: disable=redefined-builtin -from tensor2tensor.data_generators import problem_hparams -from tensor2tensor.data_generators.problem import preprocess_examples_common from tensor2tensor.utils import registry import tensorflow as tf @@ -128,25 +126,6 @@ def decode_record(record): return dataset -def preprocessing(examples, data_file_pattern): - """Preprocessing of examples.""" - # This function is for obsolete problems only, as we're porting them - # all to the Problem class and its preprocess_examples method. Don't add. - if "audio" in data_file_pattern: - # Reshape audio to proper shape - sample_count = tf.to_int32(examples.pop("audio/sample_count")) - sample_width = tf.to_int32(examples.pop("audio/sample_width")) - channel_count = 1 - examples["inputs"] = tf.reshape(examples["inputs"], - [sample_count, sample_width, channel_count]) - if "wsj" in data_file_pattern: - examples["inputs"] = tf.bitcast(examples["inputs"], tf.int32) - elif "a2q_20161229" in data_file_pattern: - # we forgot the EOS when we preprocessed this data. - examples["targets"] = tf.concat([examples["targets"], [1]], 0) - return examples - - def cast_int64_to_int32(features): f = {} for k, v in six.iteritems(features): @@ -156,51 +135,30 @@ def cast_int64_to_int32(features): return f -def feature_placeholders(data_fields): - feature_map = {} - for (field, tp) in data_fields: - if not field.startswith("targets"): - feature_map[field] = tf.placeholder( - dtype=tp, shape=[None] * 4, name=field) - return feature_map - - -def default_example_reading_spec(data_file_pattern): - """Example reading spec for problem_hparams problems.""" - # This function is for problems that have yet to be ported to the new Problem - # API. Do not add here. - data_items_to_decoders = None - # Read from image TFRecords if the file has "image" in its name. - if data_file_pattern and "image" in data_file_pattern: - label_key = "image/class/label" - data_fields = { - "image/encoded": tf.FixedLenFeature((), tf.string), - "image/format": tf.FixedLenFeature((), tf.string), - label_key: tf.VarLenFeature(tf.int64) - } +def feature_placeholders(data_fields, data_items_to_decoders): + """Construct Placeholders and run decoders.""" + example = {} + for field, config in data_fields.items(): + if isinstance(config, tf.VarLenFeature): + shape = [None] + else: + shape = config.shape + + example[field] = tf.placeholder(dtype=config.dtype, shape=shape, name=field) + + # Decode + if data_items_to_decoders is None: data_items_to_decoders = { - "inputs": - tf.contrib.slim.tfexample_decoder.Image( - image_key="image/encoded", - format_key="image/format", - channels=1 if "mnist" in data_file_pattern else 3), - "targets": - tf.contrib.slim.tfexample_decoder.Tensor(label_key), + field: tf.contrib.slim.tfexample_decoder.Tensor(field) + for field in data_fields } - elif data_file_pattern and "audio" in data_file_pattern: - data_type = tf.int64 if "timit" in data_file_pattern else tf.float32 - data_fields = { - "inputs": tf.VarLenFeature(data_type), - "audio/sample_count": tf.FixedLenFeature((), tf.int64), - "audio/sample_width": tf.FixedLenFeature((), tf.int64), - "targets": tf.VarLenFeature(tf.int64), - } - else: - data_fields = { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64) - } - return data_fields, data_items_to_decoders + + decoded_example = {} + for field, decoder in data_items_to_decoders.items(): + keys_to_tensors = {key: example[key] for key in decoder.keys} + decoded_example[field] = decoder.tensors_to_item(keys_to_tensors) + + return decoded_example def read_examples(problem, @@ -208,15 +166,11 @@ def read_examples(problem, capacity, mode=tf.estimator.ModeKeys.TRAIN): """Create Dataset of Example for problem and data_file_pattern.""" - if problem is None: - data_fields, data_items_to_decoders = default_example_reading_spec( - data_file_pattern) - else: - data_fields, data_items_to_decoders = problem.example_reading_spec() + data_fields, data_items_to_decoders = problem.example_reading_spec() if data_file_pattern is None: # Create placeholders for input, rather than reading data from disk. - return feature_placeholders(data_fields) + return feature_placeholders(data_fields, data_items_to_decoders) is_training = mode == tf.estimator.ModeKeys.TRAIN dataset = examples_reader( @@ -255,7 +209,7 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, # reading, parsing, and preprocessing. Use Problem.dataset instead. dataset = read_examples(problem, data_file_pattern, capacity, mode=mode) dataset = dataset.map( - lambda ex: _preprocess(ex, problem, data_file_pattern, hparams, mode), + lambda ex: _preprocess(ex, problem, hparams, mode), num_threads=num_threads) dataset = dataset.filter( lambda ex: example_valid_size(ex, batching_scheme["max_length"])) @@ -285,14 +239,9 @@ def input_pipeline(problem, data_file_pattern, capacity, mode, hparams, return batched_examples -def _preprocess(example, problem, data_file_pattern, hparams, mode): +def _preprocess(example, problem, hparams, mode): """Preprocessing for example.""" - if problem is None: - example = preprocess_examples_common(example, hparams) - example = preprocessing(example, data_file_pattern) - else: - example = problem.preprocess_examples(example, mode, hparams) - + example = problem.preprocess_example(example, mode, hparams) # We do not want int64s as they are not supported on GPUs. example = cast_int64_to_int32(example) @@ -367,8 +316,8 @@ def batching_fn(bucket_id, grouped_dataset): if hasattr(dataset, "apply"): # If the Dataset supports dynamic window size, use it. dataset = dataset.apply( - tf.contrib.data.group_by_window, - args=(example_to_bucket_id, batching_fn, None, window_size_fn)) + tf.contrib.data.group_by_window(example_to_bucket_id, batching_fn, + None, window_size_fn)) else: dataset = dataset.group_by_window(example_to_bucket_id, batching_fn, window_size) @@ -384,7 +333,6 @@ def padded_batch(dataset, batch_size, padded_shapes=None): def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1): """A default set of length-bucket boundaries.""" - assert min_length <= max_length assert length_bucket_step > 1.0 x = min_length boundaries = [] @@ -511,13 +459,44 @@ def get_data_filepatterns(problems, data_dir, mode): """Return the location of a dataset for a given mode.""" datasets = [] for problem in problems.split("-"): - try: - problem = registry.problem(problem).dataset_filename() - except ValueError: - problem, _, _ = problem_hparams.parse_problem_name(problem) + problem = registry.problem(problem).dataset_filename() path = os.path.join(data_dir, problem) if mode == tf.estimator.ModeKeys.TRAIN: datasets.append("%s-train*" % path) else: datasets.append("%s-dev*" % path) return datasets + + +def serving_input_fn(problem, hparams): + """Input fn for serving, starting from Placeholders.""" + data_fields, data_items_to_decoders = problem.example_reading_spec() + + # Feature placeholders that mimic what's on disk + example = feature_placeholders(data_fields, data_items_to_decoders) + + # Preprocess + example = problem.preprocess_example(example, tf.estimator.ModeKeys.PREDICT, + hparams) + example = cast_int64_to_int32(example) + + # 4-D inputs and space ids + constants = {} + constants["target_space_id"] = tf.constant( + problem.get_hparams().target_space_id) + constants["problem_choice"] = tf.constant(0) + if problem.has_inputs: + while len(example["inputs"].get_shape()) != 4: + example["inputs"] = tf.expand_dims(example["inputs"], axis=-1) + constants["input_space_id"] = tf.constant( + problem.get_hparams().input_space_id) + example.pop("targets") + else: + while len(example["targets"].get_shape()) != 4: + example["targets"] = tf.expand_dims(example["targets"], axis=-1) + + features = constants + features.update(example) + + return tf.estimator.export.ServingInputReceiver( + features=features, receiver_tensors=example) diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index f03ce6da2..4f4d7530d 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -62,9 +62,9 @@ def example_reading_spec(self): data_items_to_decoders = None return (data_fields, data_items_to_decoders) - def preprocess_examples(self, examples, unused_mode, unused_hparams): - examples["new_field"] = tf.constant([42.42]) - return examples + def preprocess_example(self, example, unused_mode, unused_hparams): + example["new_field"] = tf.constant([42.42]) + return example def generate_test_data(problem, tmp_dir): @@ -143,10 +143,10 @@ def testTrainEvalBehavior(self): def testPreprocess(self): dataset = data_reader.read_examples(self.problem, self.filepatterns[0], 32) examples = dataset.make_one_shot_iterator().get_next() - examples = data_reader._preprocess(examples, self.problem, None, None, None) + examples = data_reader._preprocess(examples, self.problem, None, None) with tf.train.MonitoredSession() as sess: ex_val = sess.run(examples) - # problem.preprocess_examples has been run + # problem.preprocess_example has been run self.assertAllClose([42.42], ex_val["new_field"]) # int64 has been cast to int32 diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index d84fd740b..a08947202 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -47,7 +47,7 @@ def decode_hparams(overrides=""): save_images=False, problem_idx=0, extra_length=50, - batch_size=32, + batch_size=0, beam_size=4, alpha=0.6, return_beams=False, @@ -74,14 +74,18 @@ def log_decode_results(inputs, (problem_name, prediction_idx)) show_and_save_image(inputs / 255., save_path) elif inputs_vocab: - decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) + if identity_output: + decoded_inputs = " ".join(map(str, inputs.flatten())) + else: + decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten())) + tf.logging.info("Inference results INPUT: %s" % decoded_inputs) decoded_targets = None if identity_output: - decoded_outputs = "".join(map(str, outputs.flatten())) + decoded_outputs = " ".join(map(str, outputs.flatten())) if targets is not None: - decoded_targets = "".join(map(str, targets.flatten())) + decoded_targets = " ".join(map(str, targets.flatten())) else: decoded_outputs = "".join( map(str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) @@ -113,7 +117,8 @@ def decode_from_dataset(estimator, hparams=hparams, data_file_patterns=infer_problems_data, num_datashards=devices.data_parallelism().n, - fixed_problem=problem_idx) + fixed_problem=problem_idx, + batch_size=decode_hp.batch_size) # Get the predictions as an iterable predictions = estimator.predict(infer_input_fn) @@ -133,6 +138,7 @@ def decode_from_dataset(estimator, inputs_vocab = problem_hparams.vocabulary.get("inputs", None) targets_vocab = problem_hparams.vocabulary["targets"] for num_predictions, prediction in enumerate(predictions): + num_predictions += 1 inputs = prediction["inputs"] targets = prediction["targets"] outputs = prediction["outputs"] @@ -188,6 +194,11 @@ def decode_from_dataset(estimator, def decode_from_file(estimator, filename, decode_hp, decode_to_file=None): """Compute predictions on entries in filename and write them out.""" + if not decode_hp.batch_size: + decode_hp.batch_size = 32 + tf.logging.info( + "decode_hp.batch_size not specified; default=%d" % decode_hp.batch_size) + hparams = estimator.params problem_id = decode_hp.problem_idx inputs_vocab = hparams.problems[problem_id].vocabulary["inputs"] diff --git a/tensor2tensor/utils/devices.py b/tensor2tensor/utils/devices.py index d04b73563..d532b6d5f 100644 --- a/tensor2tensor/utils/devices.py +++ b/tensor2tensor/utils/devices.py @@ -109,7 +109,7 @@ def _replica_device_setter(worker_device): ps_tasks=FLAGS.ps_replicas, ps_device=FLAGS.ps_job + "/GPU:0" if FLAGS.ps_gpu > 0 else FLAGS.ps_job) - if FLAGS.schedule == "local_run": + if FLAGS.schedule == "train_and_evaluate": assert not FLAGS.sync datashard_devices = ["gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)] if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1: diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index cfa782e8d..c9dde1a14 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -34,7 +34,8 @@ def build_input_fn(mode, num_datashards=None, fixed_problem=None, worker_replicas=None, - worker_id=None): + worker_id=None, + batch_size=None): """Provides input to the graph, either from disk or via a placeholder. This function produces an input function that will feed data into @@ -61,6 +62,7 @@ def build_input_fn(mode, setting with hparams.problem_choice == distributed. worker_id: int, id of this worker replica. Used in multiproblem setting with hparams.problem_choice == distributed. + batch_size: int, if provided, will use a fixed batch size. Returns: A function that returns a dictionary of features and the target labels. @@ -98,6 +100,7 @@ def input_fn(): problem_filepatterns, num_datashards, mode, + batch_size=batch_size, name="problem_%d" % problem_idx) problem_batches.append(feature_map) @@ -127,16 +130,18 @@ def input_fn(): feature_map["problem_choice"] = problem_choice # Set shapes so the ranks are clear. - feature_map["inputs"].set_shape([None, None, None, None]) + if problem_instance.has_inputs: + feature_map["inputs"].set_shape([None, None, None, None]) + feature_map["input_space_id"].set_shape([]) feature_map["targets"].set_shape([None, None, None, None]) feature_map["problem_choice"].set_shape([]) - feature_map["input_space_id"].set_shape([]) feature_map["target_space_id"].set_shape([]) if mode == tf.estimator.ModeKeys.PREDICT: feature_map["infer_targets"] = feature_map["targets"] # Forced shape obfuscation is necessary for inference. - feature_map["inputs"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access + if problem_instance.has_inputs: + feature_map["inputs"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access feature_map["targets"]._shape = tf.TensorShape([None, None, None, None]) # pylint: disable=protected-access # This is because of a bug in the Estimator that short-circuits prediction @@ -209,19 +214,25 @@ def features_for_problem(problem_instance, data_filepatterns, num_datashards, mode, + batch_size=None, name="problem_inputs"): """Feature map for Problem.""" with tf.name_scope(name): with tf.device("/cpu:0"): # Input reading on CPU capacity = (p_hparams.max_expected_batch_size_per_shard * num_datashards) + batching_scheme = data_reader.hparams_to_batching_scheme( + hparams, + shard_multiplier=num_datashards, + drop_long_sequences=(mode == tf.estimator.ModeKeys.TRAIN or + hparams.eval_drop_long_sequences), + length_multiplier=(p_hparams.batch_size_multiplier)) + if batch_size: + # If batch_size is fixed, use a single input bucket + batching_scheme["batch_sizes"] = [batch_size] + batching_scheme["boundaries"] = [] feature_map = data_reader.input_pipeline( problem_instance, data_filepatterns, capacity, mode, hparams, - data_reader.hparams_to_batching_scheme( - hparams, - shard_multiplier=num_datashards, - drop_long_sequences=(mode == tf.estimator.ModeKeys.TRAIN or - hparams.eval_drop_long_sequences), - length_multiplier=(p_hparams.batch_size_multiplier))) + batching_scheme) # Reverse inputs and targets features if the problem was reversed. if problem_instance is not None: @@ -238,11 +249,13 @@ def features_for_problem(problem_instance, feature_map["targets"] = feature_map["inputs"] # Ensure inputs and targets are proper rank. - while len(feature_map["inputs"].get_shape()) != 4: - feature_map["inputs"] = tf.expand_dims(feature_map["inputs"], axis=-1) + if problem_instance.has_inputs: + while len(feature_map["inputs"].get_shape()) != 4: + feature_map["inputs"] = tf.expand_dims(feature_map["inputs"], axis=-1) while len(feature_map["targets"].get_shape()) != 4: feature_map["targets"] = tf.expand_dims(feature_map["targets"], axis=-1) - feature_map["input_space_id"] = tf.constant(p_hparams.input_space_id) + if problem_instance.has_inputs: + feature_map["input_space_id"] = tf.constant(p_hparams.input_space_id) feature_map["target_space_id"] = tf.constant(p_hparams.target_space_id) return feature_map diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index 7c4172743..4a4717bd4 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -50,9 +50,7 @@ def model_fn(model, worker_id=0, worker_replicas=1, eval_run_autoregressive=False, - decode_hparams=None, - autotune=False, - objective=None): + decode_hparams=None): """Builds the model for all modes. * TRAIN: Constructs loss and train_op @@ -72,8 +70,6 @@ def model_fn(model, worker_replicas: int, number of workers. eval_run_autoregressive: bool, whether to run evaluation autoregressively. decode_hparams: HParams for decode settings. Used when mode == PREDICT. - autotune: bool, whether this model is being used for autotuning. - objective: str, the objective if autotune==True. Returns: tf.estimator.EstimatorSpec @@ -186,15 +182,23 @@ def nth_model(n): "problem_choice": batched_problem_choice, } _del_dict_nones(predictions) - return tf.estimator.EstimatorSpec(mode, predictions=predictions) + + export_out = {"outputs": predictions["outputs"]} + if "scores" in predictions: + export_out["scores"] = predictions["scores"] + + return tf.estimator.EstimatorSpec( + mode, + predictions=predictions, + export_outputs={ + "output": tf.estimator.export.PredictOutput(export_out) + }) total_loss, logits = model_output if mode == tf.estimator.ModeKeys.EVAL: eval_metrics_fns = metrics.create_evaluation_metrics( zip(problem_names, hparams.problem_instances), hparams) - _check_autotune_metrics( - eval_metrics_fns, autotune=autotune, objective=objective) eval_metrics = {} for metric_name, metric_fn in six.iteritems(eval_metrics_fns): @@ -391,15 +395,6 @@ def _exp_decay_after(step, rate, from_which_step): name="exponential_decay_step_cond") -def _check_autotune_metrics(metrics_dict, autotune=False, objective=None): - if not autotune: - return - - if objective not in metrics_dict: - raise ValueError("Tuning objective %s not among evaluation metrics %s" % - (objective, metrics_dict.keys())) - - def _log_variable_sizes(var_list, tag): """Log the sizes and shapes of variables, and the total size. diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 32627f7e3..3fc110ebf 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -44,7 +44,7 @@ def fn_with_timing(*args, **kwargs): return fn_with_timing -def _is_class_modality(mod): +def is_class_modality(mod): # TODO(lukaszkaiser): should be based on type, like CLASS_LABEL, not string. prefix = "class_label_modality_" if len(mod.name) < len(prefix): @@ -198,7 +198,7 @@ def infer(self, # generated sequences, than to see the most likely sequence repeatedly. beam_size = 1 self._hparams.sampling_method = "random" - if _is_class_modality( + if is_class_modality( self._hparams.problems[self._problem_idx].target_modality): beam_size = 1 # No use to run beam-search for a single class. if beam_size == 1: @@ -228,10 +228,19 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, samples: an integer `Tensor`. Top samples from the beam search """ + batch_size = tf.shape(features["inputs"])[0] + batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=") + def symbols_to_logits_fn(ids): """Go from ids to logits.""" ids = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3) ids = tf.pad(ids[:, 1:], [[0, 0], [0, 1], [0, 0], [0, 0]]) + if "partial_targets" in features: + pt = features["partial_targets"] + pt_length = tf.shape(pt)[1] + pt = tf.tile(pt, [1, beam_size]) + pt = tf.reshape(pt, [batch_size * beam_size, pt_length, 1, 1]) + ids = tf.concat([pt, ids], axis=1) features["targets"] = ids self._coverage = None @@ -247,7 +256,6 @@ def symbols_to_logits_fn(ids): logits = logits[:, current_output_position, :, :] return tf.squeeze(logits, axis=[1, 2]) - batch_size = tf.shape(features["inputs"])[0] initial_ids = tf.zeros([batch_size], dtype=tf.int32) inputs_old = features["inputs"] @@ -263,7 +271,9 @@ def symbols_to_logits_fn(ids): target_modality = self._hparams.problems[self._problem_idx].target_modality vocab_size = target_modality.top_dimensionality # Setting decode length to input length + decode_length - decode_length = tf.shape(features["inputs"])[1] + tf.constant(decode_length) + decode_length = tf.constant(decode_length) + if "partial_targets" not in features: + decode_length += tf.shape(features["inputs"])[1] ids, scores = beam_search.beam_search(symbols_to_logits_fn, initial_ids, beam_size, decode_length, vocab_size, alpha) @@ -282,7 +292,24 @@ def symbols_to_logits_fn(ids): return {"outputs": ids[:, :top_beams, 1:], "scores": scores} return ids[:, :top_beams, 1:] - def _greedy_infer(self, features, decode_length, last_position_only): + def _greedy_infer(self, features, decode_length, last_position_only): + """A greedy inference method. + + Models should ideally implement a more efficient version of this function. + + Args: + features: an map of string to `Tensor` + decode_length: an integer. How many additional timesteps to decode. + last_position_only: a boolean, speed-up by computing last position only. + + Returns: + samples: an integer `Tensor`. + logits: `Tensor` of shape [batch_size, time, 1, 1, vocab_size]. + losses: a dictionary: {loss-name (string): floating point `Scalar`} + """ + return self._slow_greedy_infer(features, decode_length, last_position_only) + + def _slow_greedy_infer(self, features, decode_length, last_position_only): """A slow greedy inference method. Quadratic time in decode_length. @@ -333,7 +360,9 @@ def infer_step(recent_output, recent_logits, unused_loss): # Create an initial output tensor. This will be passed # to the infer_step, which adds one timestep at every iteration. if "partial_targets" in features: - initial_output = tf.convert_to_tensor(features["partial_targets"]) + initial_output = tf.to_int64(tf.expand_dims( + tf.expand_dims(features["partial_targets"], 2), 3)) + batch_size = tf.shape(initial_output)[0] else: batch_size = tf.shape(features["inputs"])[0] initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64) @@ -342,7 +371,7 @@ def infer_step(recent_output, recent_logits, unused_loss): initial_output = tf.slice(initial_output, [0, 0, 0, 0], tf.shape(initial_output)) target_modality = self._hparams.problems[self._problem_idx].target_modality - if _is_class_modality(target_modality): + if is_class_modality(target_modality): decode_length = 1 else: decode_length = tf.shape(features["inputs"])[1] + decode_length @@ -366,6 +395,10 @@ def infer_step(recent_output, recent_logits, unused_loss): if inputs_old is not None: # Restore to not confuse Estimator. features["inputs"] = inputs_old losses = {"training": loss} + if "partial_targets" in features: + partial_target_length = tf.shape(features["partial_targets"])[1] + result = tf.slice( + result, [0, partial_target_length, 0, 0], [-1, -1, -1, -1]) return result, logits, losses def sample(self, features, last_position_only=False): @@ -464,6 +497,9 @@ def model_fn(self, features, skip=False, last_position_only=False): transformed_features["targets"] = target_modality.targets_bottom_sharded( sharded_features["targets"], dp) + # Allows later access to pre-embedding raw targets. + transformed_features["raw_targets"] = sharded_features["targets"] + # Construct the model body. with tf.variable_scope("body", reuse=self._problem_idx > 0): if skip: diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index 8ed7fb678..09c86ca09 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -26,7 +26,6 @@ from tensor2tensor import models # pylint: disable=unused-import from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import -from tensor2tensor.data_generators import problem_hparams from tensor2tensor.utils import data_reader from tensor2tensor.utils import decoding from tensor2tensor.utils import devices @@ -35,6 +34,7 @@ from tensor2tensor.utils import registry import tensorflow as tf +from tensorflow.contrib.hooks.python.training.profiler_hook import ProfilerHook from tensorflow.contrib.learn.python.learn import learn_runner from tensorflow.python import debug @@ -45,7 +45,10 @@ "If True, logs the contents of the registry and exits.") flags.DEFINE_bool("tfdbg", False, "If True, use the TF debugger CLI on train/eval.") -flags.DEFINE_string("output_dir", "", "Base output directory for run.") +flags.DEFINE_bool("export_saved_model", False, + "Whether to export a SavedModel for serving.") +flags.DEFINE_bool("dbgprofile", False, + "If True, record the timeline for chrome://tracing/.") flags.DEFINE_string("model", "", "Which model to use.") flags.DEFINE_string("hparams_set", "", "Which parameters to use.") flags.DEFINE_string("hparams_range", "", "Parameters range.") @@ -61,7 +64,6 @@ flags.DEFINE_string("data_dir", "/tmp/data", "Directory with training data.") flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") -flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") flags.DEFINE_bool("eval_run_autoregressive", False, "Run eval autoregressively where we condition on previous" "generated output instead of the actual target.") @@ -80,9 +82,6 @@ "Whether to log device placement.") # Distributed training flags -flags.DEFINE_string("master", "", "Address of TensorFlow master.") -flags.DEFINE_string("schedule", "local_run", - "Method of tf.contrib.learn.Experiment to run.") flags.DEFINE_integer("local_eval_frequency", 2000, "Run evaluation every this steps during local training.") flags.DEFINE_bool("locally_shard_to_cpu", False, @@ -91,7 +90,7 @@ flags.DEFINE_bool("daisy_chain_variables", True, "copy variables around in a daisy chain") flags.DEFINE_bool("sync", False, "Sync compute on PS.") -flags.DEFINE_string("worker_job", "/job:worker", "name of worker job") +flags.DEFINE_string("worker_job", "/job:localhost", "name of worker job") flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.") flags.DEFINE_integer("worker_replicas", 1, "How many workers to use.") flags.DEFINE_integer("worker_id", 0, "Which worker task are we.") @@ -113,35 +112,51 @@ def make_experiment_fn(data_dir, model_name, train_steps, eval_steps): """Returns experiment_fn for learn_runner. Wraps create_experiment.""" - def experiment_fn(output_dir): + def experiment_fn(run_config, hparams): return create_experiment( - output_dir=output_dir, - data_dir=data_dir, + data_dir, model_name=model_name, train_steps=train_steps, - eval_steps=eval_steps) + eval_steps=eval_steps, + hparams=hparams, + run_config=run_config) return experiment_fn -def create_experiment(output_dir, data_dir, model_name, train_steps, - eval_steps): +def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams, + run_config): """Create Experiment.""" - hparams = create_hparams( - FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams) - if FLAGS.worker_id == 0 and FLAGS.schedule in ["local_run", "train"]: - save_metadata(output_dir, hparams) estimator, input_fns = create_experiment_components( - hparams=hparams, - output_dir=output_dir, data_dir=data_dir, - model_name=model_name) + model_name=model_name, + hparams=hparams, + run_config=run_config) + train_monitors = [] eval_hooks = [] if FLAGS.tfdbg: hook = debug.LocalCLIDebugHook() train_monitors.append(hook) eval_hooks.append(hook) + if FLAGS.dbgprofile: + # Recorded traces can be visualized with chrome://tracing/ + # The memory/tensor lifetime is also profiled + train_monitors.append(ProfilerHook( + save_steps=10, + output_dir=run_config.model_dir, + show_dataflow=True, + show_memory=True, + )) + + optional_kwargs = {} + if FLAGS.export_saved_model: + assert len(hparams.problem_instances) == 1 + problem = hparams.problem_instances[0] + optional_kwargs["export_strategies"] = [ + make_export_strategy(problem, hparams) + ] + return tf.contrib.learn.Experiment( estimator=estimator, train_input_fn=input_fns[tf.estimator.ModeKeys.TRAIN], @@ -150,12 +165,21 @@ def create_experiment(output_dir, data_dir, model_name, train_steps, eval_steps=eval_steps, min_eval_frequency=FLAGS.local_eval_frequency, train_monitors=train_monitors, - eval_hooks=eval_hooks) + eval_hooks=eval_hooks, + **optional_kwargs) + +def make_export_strategy(problem, hparams): + return tf.contrib.learn.make_export_strategy( + lambda: data_reader.serving_input_fn(problem, hparams), as_text=True) -def create_experiment_components(hparams, output_dir, data_dir, model_name): + +def create_experiment_components(data_dir, model_name, hparams, run_config): """Constructs and returns Estimator and train/eval input functions.""" - tf.logging.info("Creating experiment, storing model files in %s", output_dir) + tf.logging.info("Creating experiment, storing model files in %s", + run_config.model_dir) + + hparams = add_problem_hparams(hparams, FLAGS.problems) num_datashards = devices.data_parallelism().n train_input_fn = input_fn_builder.build_input_fn( @@ -176,11 +200,6 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): worker_replicas=FLAGS.worker_replicas, worker_id=FLAGS.worker_id) - autotune = False - objective = None - if hasattr(FLAGS, "autotune"): - autotune = FLAGS.autotune - objective = FLAGS.objective model_fn = model_builder.build_model_fn( model_name, problem_names=FLAGS.problems.split("-"), @@ -188,20 +207,13 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name): worker_id=FLAGS.worker_id, worker_replicas=FLAGS.worker_replicas, eval_run_autoregressive=FLAGS.eval_run_autoregressive, - decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), - autotune=autotune, - objective=objective) + decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams)) + estimator = tf.estimator.Estimator( model_fn=model_fn, - model_dir=output_dir, + model_dir=run_config.model_dir, params=hparams, - config=tf.contrib.learn.RunConfig( - master=FLAGS.master, - gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction, - session_config=session_config(), - keep_checkpoint_max=FLAGS.keep_checkpoint_max, - keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours, - save_checkpoints_secs=FLAGS.save_checkpoints_secs)) + config=run_config) return estimator, { tf.estimator.ModeKeys.TRAIN: train_input_fn, @@ -223,24 +235,12 @@ def add_problem_hparams(hparams, problems): try: problem = registry.problem(problem_name) except LookupError: - problem = None - - if problem is None: - try: - p_hparams = problem_hparams.problem_hparams(problem_name, hparams) - except LookupError: - # The problem is not in the set of registered Problems nor in the old - # set of problem_hparams. - all_problem_names = sorted( - list(problem_hparams.PROBLEM_HPARAMS_MAP) + - registry.list_problems()) - error_lines = [ - "%s not in the set of supported problems:" % problem_name - ] + all_problem_names - error_msg = "\n * ".join(error_lines) - raise LookupError(error_msg) - else: - p_hparams = problem.get_hparams(hparams) + all_problem_names = sorted(registry.list_problems()) + error_lines = ["%s not in the set of supported problems:" % problem_name + ] + all_problem_names + error_msg = "\n * ".join(error_lines) + raise LookupError(error_msg) + p_hparams = problem.get_hparams(hparams) hparams.problem_instances.append(problem) hparams.problems.append(p_hparams) @@ -279,7 +279,7 @@ def save_metadata(output_dir, hparams): f.write(hparams.to_json()) -def create_hparams(params_id, problems, data_dir, passed_hparams=None): +def create_hparams(params_id, data_dir, passed_hparams=None): """Returns hyperparameters, including any flag value overrides. If the hparams FLAG is set, then it will use any values specified in @@ -288,7 +288,6 @@ def create_hparams(params_id, problems, data_dir, passed_hparams=None): Args: params_id: which set of parameters to choose (must be in _PARAMS above). - problems: the string with problem names to get problem_hparams from. data_dir: the directory containing the training data. passed_hparams: command-line overrides for some hparams. @@ -301,16 +300,26 @@ def create_hparams(params_id, problems, data_dir, passed_hparams=None): if passed_hparams: hparams = hparams.parse(passed_hparams) - return add_problem_hparams(hparams, problems) + return hparams -def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): - """Runs an Estimator locally or distributed. +def create_run_config(output_dir): + """Create a RunConfig object.""" + + run_config = tf.contrib.learn.RunConfig( + model_dir=output_dir, + master=FLAGS.master, + gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction, + session_config=session_config(), + keep_checkpoint_max=FLAGS.keep_checkpoint_max, + keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours, + save_checkpoints_secs=FLAGS.save_checkpoints_secs) - This function chooses one of two paths to execute: + return run_config - 1. Running locally if schedule=="local_run". - 3. Distributed training/evaluation otherwise. + +def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): + """Runs an Estimator locally or distributed. Args: data_dir: The directory the data can be found in. @@ -327,22 +336,19 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule): train_steps=train_steps, eval_steps=eval_steps) - if schedule == "local_run": - # Run the local demo. - exp = exp_fn(output_dir) - if exp.train_steps > 0 and exp.eval_steps > 0: - tf.logging.info("Performing local training and evaluation.") - exp.train_and_evaluate() - elif exp.train_steps > 0: - tf.logging.info("Performing local training.") - exp.train() - elif exp.eval_steps > 0: - tf.logging.info("Performing local evaluation.") - exp.evaluate(delay_secs=0) - else: - # Perform distributed training/evaluation. - learn_runner.run( - experiment_fn=exp_fn, schedule=schedule, output_dir=output_dir) + # Create hparams and run_config + run_config = create_run_config(output_dir) + hparams = create_hparams( + FLAGS.hparams_set, data_dir, passed_hparams=FLAGS.hparams) + + if is_chief(): + save_metadata(output_dir, hparams) + + learn_runner.run( + experiment_fn=exp_fn, + schedule=schedule, + run_config=run_config, + hparams=hparams) def validate_flags(): @@ -360,6 +366,11 @@ def validate_flags(): "Using default output_dir=%s.", FLAGS.output_dir) +def is_chief(): + schedules = ["train", "train_and_evaluate"] + return FLAGS.worker_id == 0 and FLAGS.schedule in schedules + + def session_config(): """The TensorFlow Session config to use.""" graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions( diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index 6045dd2e0..16a8149f4 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -33,8 +33,14 @@ import tensorflow as tf +flags = tf.flags FLAGS = tf.flags.FLAGS +flags.DEFINE_string("schedule", "train_and_evaluate", "") +flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_string("master", "", "Address of TensorFlow master.") +flags.DEFINE_string("output_dir", "", "Base output directory for run.") + @registry.register_problem class TinyAlgo(algorithmic.AlgorithmicIdentityBinary40): @@ -84,13 +90,17 @@ def testHParamsImported(self): def testSingleStep(self): model_name = "transformer" - FLAGS.hparams_set = "transformer_test" + data_dir = TrainerUtilsTest.data_dir + hparams = trainer_utils.create_hparams("transformer_test", data_dir) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) exp = trainer_utils.create_experiment( - output_dir=tf.test.get_temp_dir(), - data_dir=TrainerUtilsTest.data_dir, + data_dir=data_dir, model_name=model_name, train_steps=1, - eval_steps=1) + eval_steps=1, + hparams=hparams, + run_config=trainer_utils.create_run_config( + output_dir=tf.test.get_temp_dir())) exp.test() def testSingleEvalStepRawSession(self): @@ -104,8 +114,8 @@ def testSingleEvalStepRawSession(self): # Create the problem object, hparams, placeholders, features dict. encoders = registry.problem(FLAGS.problems).feature_encoders(data_dir) - hparams = trainer_utils.create_hparams(FLAGS.hparams_set, FLAGS.problems, - data_dir) + hparams = trainer_utils.create_hparams(FLAGS.hparams_set, data_dir) + hparams = trainer_utils.add_problem_hparams(hparams, FLAGS.problems) inputs_ph = tf.placeholder(dtype=tf.int32) # Just length dimension. batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1]) # Make it 4D. # In INFER mode targets can be None.