forked from lcdevelop/ChatBotCourse
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathseq2seq_patch.py
More file actions
143 lines (129 loc) · 6.33 KB
/
seq2seq_patch.py
File metadata and controls
143 lines (129 loc) · 6.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq
def embedding_attention_seq2seq(encoder_inputs,
                                decoder_inputs,
                                enc_cell,
                                dec_cell,
                                num_encoder_symbols,
                                num_decoder_symbols,
                                embedding_size,
                                num_heads=1,
                                output_projection=None,
                                feed_previous=False,
                                dtype=None,
                                scope=None,
                                initial_state_attention=False):
    """Build an embedding seq2seq model with attention and separate cells.

    Unlike a single-cell variant, the encoder and the decoder each receive
    their own cell object (`enc_cell` and `dec_cell`).

    The encoder inputs are embedded through a newly created embedding (of
    shape [num_encoder_symbols x input_size]) and fed through an RNN. Every
    per-step encoder output is kept and concatenated so the decoder can attend
    over it. The decoder inputs are embedded through another newly created
    embedding (of shape [num_decoder_symbols x input_size]) and run through an
    attention decoder that starts from the final encoder state.

    Warning: when output_projection is None, the size of the attention
    vectors and variables will be made proportional to num_decoder_symbols,
    which can be large.

    Args:
      encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
      decoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
      enc_cell: RNN cell used for the encoder; it is wrapped in an
        EmbeddingWrapper before being run with static_rnn.
      dec_cell: RNN cell used for the decoder; when output_projection is
        None it is wrapped in an OutputProjectionWrapper that projects to
        num_decoder_symbols.
      num_encoder_symbols: Integer; number of symbols on the encoder side.
      num_decoder_symbols: Integer; number of symbols on the decoder side.
      embedding_size: Integer, the length of the embedding vector for each
        symbol.
      num_heads: Number of attention heads that read from attention_states.
      output_projection: None or a pair (W, B) of output projection weights
        and biases; W has shape [output_size x num_decoder_symbols] and B has
        shape [num_decoder_symbols]; if provided and feed_previous=True, each
        fed previous output will first be multiplied by W and added B.
      feed_previous: Boolean or scalar Boolean Tensor; if True, only the
        first of decoder_inputs will be used (the "GO" symbol), and all other
        decoder inputs will be taken from previous outputs (as in
        embedding_rnn_decoder). If False, decoder_inputs are used as given
        (the standard decoder case).
      dtype: The dtype of the initial RNN state (default: tf.float32).
      scope: VariableScope for the created subgraph; defaults to
        "embedding_attention_seq2seq".
      initial_state_attention: If False (default), initial attentions are
        zero. If True, initialize the attentions from the initial state and
        attention states.

    Returns:
      A tuple of the form (outputs, state), where:
        outputs: A list of the same length as decoder_inputs of 2D Tensors
          with shape [batch_size x num_decoder_symbols] containing the
          generated outputs.
        state: The state of the decoder cell at the final time-step
          (presumably a 2D Tensor of shape [batch_size x state_size], or a
          nested structure matching the encoder state -- confirm against the
          cells in use).
    """
    with variable_scope.variable_scope(
            scope or "embedding_attention_seq2seq", dtype=dtype) as model_scope:
        state_dtype = model_scope.dtype

        # Encoder: embed the symbol ids, then unroll the RNN over all steps.
        embedded_encoder = core_rnn_cell.EmbeddingWrapper(
            enc_cell,
            embedding_classes=num_encoder_symbols,
            embedding_size=embedding_size)
        encoder_outputs, encoder_state = rnn.static_rnn(
            embedded_encoder, encoder_inputs, dtype=state_dtype)

        # Give each encoder output a length-1 middle axis and join them along
        # that axis; the result is what the decoder attends over.
        per_step_states = []
        for step_output in encoder_outputs:
            per_step_states.append(
                array_ops.reshape(
                    step_output, [-1, 1, embedded_encoder.output_size]))
        attention_states = array_ops.concat(per_step_states, 1)

        # Decoder: without an explicit projection, the cell itself projects
        # its output to the decoder vocabulary size.
        if output_projection is None:
            decoder_cell = core_rnn_cell.OutputProjectionWrapper(
                dec_cell, num_decoder_symbols)
            output_size = num_decoder_symbols
        else:
            decoder_cell = dec_cell
            output_size = None

        def run_attention_decoder(use_previous, **extra_kwargs):
            # Single place that wires the shared arguments into the decoder.
            return seq2seq.embedding_attention_decoder(
                decoder_inputs,
                encoder_state,
                attention_states,
                decoder_cell,
                num_decoder_symbols,
                embedding_size,
                num_heads=num_heads,
                output_size=output_size,
                output_projection=output_projection,
                feed_previous=use_previous,
                initial_state_attention=initial_state_attention,
                **extra_kwargs)

        # A plain Python bool needs only one decoder graph.
        if isinstance(feed_previous, bool):
            return run_attention_decoder(feed_previous)

        # feed_previous is a Tensor: build one decoder graph per truth value
        # and pick between them at run time with cond.
        def build_branch(use_previous):
            # The False branch reuses variables instead of creating new ones
            # (presumably because cond builds the True branch first -- confirm).
            branch_reuse = None if use_previous else True
            with variable_scope.variable_scope(
                    variable_scope.get_variable_scope(), reuse=branch_reuse):
                branch_outputs, branch_state = run_attention_decoder(
                    use_previous, update_embedding_for_previous=False)
                # Each branch hands back one flat list: the outputs followed
                # by the (flattened) state.
                if nest.is_sequence(branch_state):
                    flat_branch_state = nest.flatten(branch_state)
                else:
                    flat_branch_state = [branch_state]
                return branch_outputs + flat_branch_state

        merged = control_flow_ops.cond(
            feed_previous,
            lambda: build_branch(True),
            lambda: build_branch(False))

        # Outputs length is the same as decoder inputs; the rest is state.
        num_outputs = len(decoder_inputs)
        flat_state = merged[num_outputs:]
        final_state = flat_state[0]
        if nest.is_sequence(encoder_state):
            # Restore the nested structure that was flattened in the branches.
            final_state = nest.pack_sequence_as(
                structure=encoder_state, flat_sequence=flat_state)
        return merged[:num_outputs], final_state