Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a852994
Attention moe can mix attention layer types
Sep 11, 2017
017f83a
Bug fixes in masked_local_attention_2d and local_attention_2d. We nee…
Sep 11, 2017
b5db405
Fix the pad_remover for attention expert when hybrid attention layers…
Sep 11, 2017
15682d5
Added a new model "aligned" for aligned sequence problems without aut…
nshazeer Sep 11, 2017
a7c7087
Added tests for 2-d local attention. Refactoring to use dot_product_a…
Sep 13, 2017
4f07375
use the right value for shape
Sep 13, 2017
802b95f
Separate out encoding a decoding steps.
a-googler Sep 13, 2017
466ce80
Split out timing signal function.
a-googler Sep 13, 2017
79ba4a8
Adding has_inputs property to Problem.
a-googler Sep 13, 2017
7035ffe
Allowing explicit timing positions to be used, by adding function add…
a-googler Sep 13, 2017
2138599
fix off-by-one num_samples bug in decode_from_dataset
Sep 14, 2017
3237538
Use decode_hparams.batch_size when decoding from dataset
Sep 14, 2017
e6e4263
Add wiki_scramble_128 dataset.
nshazeer Sep 15, 2017
6cb0bc8
Add ability to average the last N checkpoints, without needing to spe…
a-googler Sep 15, 2017
be19196
Working on a model for cnn_dailymail summarization task. Make greed…
nshazeer Sep 15, 2017
6970dea
Change ptb data generator to encode end of sentences with <EOS> tags …
MechCoder Sep 18, 2017
3aa1368
Rename ambiguous function names.
a-googler Sep 18, 2017
558fe96
Move the final layer_preprocess in the encoder and decoder in to the …
a-googler Sep 18, 2017
1e712d3
More experiments with "aligned" model and wiki_scramble dataset.
nshazeer Sep 18, 2017
1c7d365
Initial version of fast decoding for transformer models.
a-googler Sep 18, 2017
aa40c4b
Update experiment function signature to (run_config, hparams)
a-googler Sep 19, 2017
aec87db
[tf.contrib.data] Standardize transformation functions for use with `…
a-googler Sep 19, 2017
0b8573c
@recompute_grad decorator
Sep 19, 2017
77e91f6
Register `lstm_seq2seq` hparams.
a-googler Sep 19, 2017
12126bd
Add flag to profile ops/memory
Sep 19, 2017
9d63460
Enable fast decoding.
a-googler Sep 19, 2017
bc191b5
Fix formatting in identity output
Sep 19, 2017
f07b59f
Fix output shape of TransformerEncoder
Sep 20, 2017
21b3b55
SavedModel export and decoding fixes
Sep 20, 2017
620d6a5
Add Travis build shield to README
Sep 20, 2017
4280f44
Rm all refs to local_run in favor of train_and_evaluate
Sep 20, 2017
0841742
Support class modality in fast decoding.
a-googler Sep 21, 2017
09f1f17
Minimally port remaining problems to Problem class
Sep 21, 2017
f191c78
Correct README for decoding
Sep 21, 2017
c996878
Reproduces a bug with the SubwordTextEncoder in a test.
a-googler Sep 21, 2017
8ee8350
v1.2.3
Sep 21, 2017
e892dc3
Update example_life.md
Sep 22, 2017
6237729
Fix travis shield link
Sep 22, 2017
76706ef
Make output of fn in @recompute_grad a list to avoid trying to concat…
Sep 22, 2017
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](http
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
[![Travis](https://img.shields.io/travis/tensorflow/tensor2tensor.svg)](https://travis-ci.org/tensorflow/tensor2tensor)

[T2T](https://github.com/tensorflow/tensor2tensor) is a modular and extensible
library and binaries for supervised learning with TensorFlow and with support
Expand Down Expand Up @@ -123,8 +124,7 @@ t2t-decoder \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--decode_beam_size=$BEAM_SIZE \
--decode_alpha=$ALPHA \
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
--decode_from_file=$DECODE_FILE

cat $DECODE_FILE.$MODEL.$HPARAMS.beam$BEAM_SIZE.alpha$ALPHA.decodes
Expand Down
195 changes: 179 additions & 16 deletions docs/example_life.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,189 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

This document show how a training example passes through the T2T pipeline,
and how all its parts are connected to work together.
This doc explains how a training example flows through T2T, from data generation
to training, evaluation, and decoding. It points out the various hooks available
in the `Problem` and `T2TModel` classes and gives an overview of the T2T code
(key functions, files, hyperparameters, etc.).

## The Life of an Example
Some key files and their functions:

A training example passes the following stages in T2T:
* raw input (text from command line or file)
* encoded input after [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `encode` is usually a sparse tensor, e.g., a vector of `tf.int32`s
* batched input after [data input pipeline](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/data_reader.py#L242) where the inputs, after [Problem.preprocess_examples](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L188) are grouped by their length and made into batches.
* dense input after being processed by a [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `bottom`.
* dense output after [T2T.model_fn_body](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/t2t_model.py#L542)
* back to sparse output through [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) function `top`.
* if decoding, back through [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) function `decode` to display on the screen.
* [`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py):
Constructs and runs all the main components of the system (the `Problem`,
the `HParams`, the `Estimator`, the `Experiment`, the `input_fn`s and
`model_fn`).
* [`common_hparams.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/common_hparams.py):
`basic_params1` serves as the base for all model hyperparameters. Registered
model hparams functions always start with this default set of
hyperparameters.
* [`problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py):
Every dataset in T2T subclasses `Problem`.
* [`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py):
Every model in T2T subclasses `T2TModel`.

We go into these phases step by step below.
## Data Generation

## Feature Encoders
The `t2t-datagen` binary is the entrypoint for data generation. It simply looks
up the `Problem` specified by `--problem` and calls
`Problem.generate_data(data_dir, tmp_dir)`.

TODO: describe [Problem.feature_encoder](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L173) which is a dict of encoders that have `encode` and `decode` functions.
All `Problem`s are expected to generate 2 sharded `TFRecords` files - 1 for
training and 1 for evaluation - with `tensorflow.Example` protocol buffers. The
expected names of the files are given by `Problem.{training, dev}_filepaths`.
Typically, the features in the `Example` will be `"inputs"` and `"targets"`;
however, some tasks have a different on-disk representation that is converted to
`"inputs"` and `"targets"` online in the input pipeline (e.g. image features are
typically stored with features `"image/encoded"` and `"image/format"` and the
decoding happens in the input pipeline).

## Modalities
For tasks that require a vocabulary, this is also the point at which the
vocabulary is generated and all examples are encoded.

TODO: describe [Modality](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/modality.py#L30) which has `bottom` and `top` but also sharded versions and one for targets.
There are several utility functions in
[`generator_utils`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/generator_utils.py)
that are commonly used by `Problem`s to generate data. Several are highlighted
below:

* `generate_dataset_and_shuffle`: given 2 generators, 1 for training and 1 for
eval, yielding dictionaries of `<feature name, list< int or float or
string >>`, will produce sharded and shuffled `TFRecords` files with
`tensorflow.Example` protos.
* `maybe_download`: downloads a file at a URL to the given directory and
filename (see `maybe_download_from_drive` if the URL points to Google
Drive).
* `get_or_generate_vocab_inner`: given a target vocabulary size and a
generator that yields lines or tokens from the dataset, will build a
`SubwordTextEncoder` along with a backing vocabulary file that can be used
to map input strings to lists of ids.
[`SubwordTextEncoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py)
uses word pieces and its encoding is fully invertible.

## Data Input Pipeline

Once the data is produced on disk, training, evaluation, and inference (if
decoding from the dataset) consume it by way of T2T input pipeline. This section
will give an overview of that pipeline with specific attention to the various
hooks in the `Problem` class and the model's `HParams` object (typically
registered in the model's file and specified by the `--hparams_set` flag).

The entire input pipeline is implemented with the new `tf.data.Dataset` API
(previously `tf.contrib.data.Dataset`).

The key function in the codebase for the input pipeline is
[`data_reader.input_pipeline`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/data_reader.py).
The full input function is built in
[`input_fn_builder.build_input_fn`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/input_fn_builder.py)
(which calls `data_reader.input_pipeline`).

### Reading and decoding data

`Problem.dataset_filename` specifies the prefix of the files on disk (they will
be suffixed with `-train` or `-dev` as well as their sharding).

The features read from the files and their decoding is specified by
`Problem.example_reading_spec`, which returns 2 items:

1. Dict mapping from on-disk feature name to on-disk types (`VarLenFeature` or
`FixedLenFeature`.
2. Dict mapping output feature name to decoder. This return value is optional
and is only needed for tasks whose features may require additional decoding
(e.g. images). You can find the available decoders in
`tf.contrib.slim.tfexample_decoder`.

At this point in the input pipeline, the example is a `dict<feature name,
Tensor>`.

### Preprocessing

The read `Example` now runs through `Problem.preprocess_example`, which by
default runs
[`problem.preprocess_example_common`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py),
which may truncate the inputs/targets or prepend to targets, governed by some
hyperparameters.

### Batching

Examples are bucketed by sequence length and then batched out of those buckets.
This significantly improves performance over a naive batching scheme for
variable length sequences because each example in a batch must be padded to
match the example with the maximum length in the batch.

There are several hyperparameters that affect how examples are batched together:

* `hp.batch_size`: this is the approximate total number of tokens in the batch
(i.e. for a sequence problem, long sequences will have smaller actual batch
size and short sequences will have a larger actual batch size in order to
generally have an equal number of tokens in the batch).
* `hp.max_length`: sequences with length longer than this will be dropped
during training (and also during eval if `hp.eval_drop_long_sequences` is
`True`). If not set, the maximum length of examples is set to
`hp.batch_size`.
* `hp.batch_size_multiplier`: multiplier for the maximum length
* `hp.min_length_bucket`: example length for the smallest bucket (i.e. the
smallest bucket will bucket examples up to this length).
* `hp.length_bucket_step`: controls how spaced out the length buckets are.

## Building the Model

At this point, the input features typically have `"inputs"` and `"targets"`,
each of which is a batched 4-D Tensor (e.g. of shape `[batch_size,
sequence_length, 1, 1]` for text input or `[batch_size, height, width, 3]` for
image input).

A `T2TModel` is composed of transforms of the input features by `Modality`s,
then the body of the model, then transforms of the model output to predictions
by a `Modality`, and then a loss (during training).

The `Modality` types for the various input features and for the target are
specified in `Problem.hparams`. A `Modality` is a feature adapter that enables
models to be agnostic to input/output spaces. You can see the various
`Modality`s in
[`modalities.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/layers/modalities.py).

The sketch structure of a T2T model is as follows:

```python
features = {...} # output from the input pipeline
input_modaly = ... # specified in Problem.hparams
target_modality = ... # specified in Problem.hparams

transformed_features = {}
transformed_features["inputs"] = input_modality.bottom(
features["inputs"])
transformed_features["targets"] = target_modality.targets_bottom(
features["targets"]) # for autoregressive models

body_outputs = model.model_fn_body(transformed_features)

predictions = target_modality.top(body_outputs, features["targets"])
loss = target_modality.loss(predictions, features["targets"])
```

Most `T2TModel`s only override `model_fn_body`.

## Training, Eval, Inference modes

Both the input function and model functions take a mode in the form of a
`tf.estimator.ModeKeys`, which allows the functions to behave differently in
different modes.

In training, the model function constructs an optimizer and minimizes the loss.

In evaluation, the model function constructs the evaluation metrics specified by
`Problem.eval_metrics`.

In inference, the model function outputs predictions.

## `Estimator` and `Experiment`

With the input function and model functions constructed, the actual training
loop and related services (checkpointing, summaries, continuous evaluation,
etc.) are all handled by `Estimator` and `Experiment` objects, constructed in
[`trainer_utils.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/trainer_utils.py).

## Decoding

* [`decoding.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/decoding.py)

TODO(rsepassi): Explain decoding (interactive, from file, and from dataset) and
`Problem.feature_encoders`.
9 changes: 2 additions & 7 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,6 @@ documentation, from basic tutorials to full code documentation.

## Deep Dive

* [Life of an Example](example_life.md): how all parts of T2T are connected and work together
* [Life of an Example](example_life.md): how all parts of T2T are connected and
work together
* [Distributed Training](distributed_training.md)

## Code documentation

See our
[README](https://github.com/tensorflow/tensor2tensor/blob/master/README.md)
for now, code docs coming.
Loading