forked from tensorflow/tensor2tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvisualization.py
More file actions
201 lines (164 loc) · 6.96 KB
/
visualization.py
File metadata and controls
201 lines (164 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared code for visualizing transformer attentions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Dependency imports
import numpy as np
# To register the hparams set
from tensor2tensor import models # pylint: disable=unused-import
from tensor2tensor import problems
from tensor2tensor.utils import registry
from tensor2tensor.utils import trainer_lib
import tensorflow as tf
# Token id appended to every encoded input sequence by
# `AttentionVisualizer.encode`; presumably the vocabulary's end-of-sequence
# id -- confirm against the problem's text encoder.
EOS_ID = 1
class AttentionVisualizer(object):
  """Helper object for creating Attention visualizations.

  Builds the model graph once (via `build_model`) and keeps handles to the
  input/target placeholders, the decoded-output tensor and the attention
  tensors, together with the problem's feature encoders used to convert
  between strings and token ids.
  """

  def __init__(
      self, hparams_set, model_name, data_dir, problem_name, beam_size=1):
    inputs, targets, samples, att_mats = build_model(
        hparams_set, model_name, data_dir, problem_name, beam_size=beam_size)

    # Fetch the problem so we can get its feature encoders.
    ende_problem = problems.problem(problem_name)
    encoders = ende_problem.feature_encoders(data_dir)

    # Placeholder fed with the encoded input ids.
    self.inputs = inputs
    # Placeholder fed with the decoded translation when fetching attentions.
    self.targets = targets
    # (enc_atts, dec_atts, encdec_atts) attention tensors.
    self.att_mats = att_mats
    # Tensor holding the ids of the decoded translation.
    self.samples = samples
    # Dict of feature encoders returned by `feature_encoders`, keyed by
    # feature name ('inputs', and usually 'targets').
    self.encoders = encoders

  @staticmethod
  def _to_id_list(integers):
    """Flattens an array-like of ids into a plain list of Python ints.

    Unlike `list(np.squeeze(x))`, this also works when `x` holds a single
    id: squeezing such an array yields a 0-d array, which is not iterable.
    """
    return np.asarray(integers).reshape(-1).tolist()

  def encode(self, input_str):
    """Converts an input string into a batch of ids ready for inference.

    Args:
      input_str: The sentence to encode.

    Returns:
      A numpy array of shape [1, num_ids, 1, 1] holding the encoded ids
      followed by `EOS_ID`.
    """
    inputs = self.encoders['inputs'].encode(input_str) + [EOS_ID]
    # Make it 4D ([batch=1, length, 1, 1]) to match the `inputs` placeholder.
    batch_inputs = np.reshape(inputs, [1, -1, 1, 1])
    return batch_inputs

  def decode(self, integers, encoder_key='inputs'):
    """Converts ids to a string.

    Args:
      integers: Array-like of token ids of any shape (e.g. [1, length, 1, 1]);
        it is flattened before decoding.
      encoder_key: Name of the feature encoder in `self.encoders` to decode
        with. Defaults to 'inputs'.

    Returns:
      The decoded string.
    """
    return self.encoders[encoder_key].decode(self._to_id_list(integers))

  def decode_list(self, integers, encoder_key='inputs'):
    """Converts ids to a list of token strings.

    Args:
      integers: Array-like of token ids of any shape (e.g. [1, length, 1, 1]);
        it is flattened before decoding.
      encoder_key: Name of the feature encoder in `self.encoders` to decode
        with. Defaults to 'inputs'.

    Returns:
      A list with one string per token id.
    """
    return self.encoders[encoder_key].decode_list(self._to_id_list(integers))

  def get_vis_data_from_string(self, sess, input_string):
    """Constructs the data needed for visualizing attentions.

    Args:
      sess: A tf.Session object.
      input_string: The input sentence to be translated and visualized.

    Returns:
      Tuple of (
          output_string: The translated sentence.
          input_list: Tokenized input sentence.
          output_list: Tokenized translation.
          att_mats: Tuple of attention matrices; (
              enc_atts: Encoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, inp_len, inp_len)
              dec_atts: Decoder self attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, out_len)
              encdec_atts: Encoder-Decoder attention weights.
                A list of `num_layers` numpy arrays of size
                (batch_size, num_heads, out_len, inp_len)
          )
    """
    encoded_inputs = self.encode(input_string)

    # Run inference graph to get the translation.
    out = sess.run(self.samples, {
        self.inputs: encoded_inputs,
    })

    # Run the decoded translation through the training graph to get the
    # attention tensors.
    att_mats = sess.run(self.att_mats, {
        self.inputs: encoded_inputs,
        self.targets: np.reshape(out, [1, -1, 1, 1]),
    })

    # The decoded ids are target-side ids, so prefer the targets encoder when
    # the problem provides one. Fall back to the inputs encoder otherwise;
    # for shared-vocabulary problems the two are expected to coincide.
    output_key = 'targets' if 'targets' in self.encoders else 'inputs'
    output_string = self.decode(out, encoder_key=output_key)
    input_list = self.decode_list(encoded_inputs)
    output_list = self.decode_list(out, encoder_key=output_key)

    return output_string, input_list, output_list, att_mats
def build_model(hparams_set, model_name, data_dir, problem_name, beam_size=1):
  """Builds the graph needed to fetch translations and attention weights.

  Args:
    hparams_set: Name of the HParams set used to configure the model.
    model_name: Registered name of the model to build.
    data_dir: Directory holding the problem's training data.
    problem_name: Registered name of the problem.
    beam_size: (Optional) Beam width used when decoding a translation.
      The default of 1 means greedy decoding.

  Returns:
    Tuple of (
        inputs: Placeholder to feed the ids to be translated into.
        targets: Placeholder to feed the translation into when fetching the
          attention weights.
        samples: Tensor with the ids of the decoded translation.
        att_mats: Tensors with the attention weights.
    )
  """
  hparams = trainer_lib.create_hparams(
      hparams_set, data_dir=data_dir, problem_name=problem_name)
  model_cls = registry.model(model_name)
  model = model_cls(hparams, tf.estimator.ModeKeys.EVAL)

  id_shape = (1, None, 1, 1)
  inputs = tf.placeholder(tf.int32, shape=id_shape, name='inputs')
  targets = tf.placeholder(tf.int32, shape=id_shape, name='targets')

  # Build the training graph; this populates the model's attention dict.
  model(dict(inputs=inputs, targets=targets))

  # Ordering matters: collect the attention tensors only after the training
  # graph exists, and before the inference graph is built. Otherwise the
  # dict would be refilled with tensors created inside the decoding
  # tf.while_loop, which are marked unfetchable.
  att_mats = get_att_mats(model)

  # Inference shares the variables created by the training graph above.
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    infer_result = model.infer(dict(inputs=inputs), beam_size=beam_size)
    samples = infer_result['outputs']

  return inputs, targets, samples, att_mats
def get_att_mats(translate_model, prefix='transformer/body/'):
  """Gets the tensors representing the attentions from a built model.

  While the graph is built, the attentions are stored in a dict
  (`attention_weights`) on the model object. That dict is keyed by scope
  name, and this function looks the entries up by name.

  Args:
    translate_model: Transformer object to fetch the attention weights from.
    prefix: Scope prefix under which the encoder/decoder layers were built.
      Defaults to 'transformer/body/'. Override it for models built under a
      different top-level scope.

  Returns:
    Tuple of attention tensors; (
        enc_atts: Encoder self attention weights.
          A list of `num_layers` tensors of size
          (batch_size, num_heads, inp_len, inp_len)
        dec_atts: Decoder self attention weights.
          A list of `num_layers` tensors of size
          (batch_size, num_heads, out_len, out_len)
        encdec_atts: Encoder-Decoder attention weights.
          A list of `num_layers` tensors of size
          (batch_size, num_heads, out_len, inp_len)
    )

  Raises:
    KeyError: If an expected entry is missing from
      `translate_model.attention_weights`.
  """
  hparams = translate_model.hparams
  weights = translate_model.attention_weights

  postfix_encdec = '/multihead_attention/dot_product_attention'
  postfix_self = postfix_encdec
  # Relative self-attention records its weights under a differently named
  # scope. Encoder-decoder attention is not affected.
  if getattr(hparams, 'self_attention_type', None) == 'dot_product_relative':
    postfix_self = '/multihead_attention/dot_product_attention_relative'

  def _lookup(name):
    # Re-raise KeyError with the available keys to ease debugging of
    # scope-name mismatches.
    try:
      return weights[name]
    except KeyError:
      raise KeyError('Attention weights %r not found; available keys: %s'
                     % (name, sorted(weights)))

  enc_atts = []
  dec_atts = []
  encdec_atts = []
  for i in range(hparams.num_hidden_layers):
    enc_atts.append(_lookup(
        '%sencoder/layer_%i/self_attention%s' % (prefix, i, postfix_self)))
    dec_atts.append(_lookup(
        '%sdecoder/layer_%i/self_attention%s' % (prefix, i, postfix_self)))
    encdec_atts.append(_lookup(
        '%sdecoder/layer_%i/encdec_attention%s' % (prefix, i, postfix_encdec)))

  return enc_atts, dec_atts, encdec_atts