Skip to content

Commit 1da1f63

Browse files
Lukasz Kaisertensorflower-gardener
authored andcommitted
Use only op_scope, not variable_op_scope, in functional ops since they do not
create variables. Also add missing output_size in EmbeddingWrapper (tensorflow#2852). Change: 125022470
1 parent 67d3c91 commit 1da1f63

5 files changed

Lines changed: 138 additions & 12 deletions

File tree

tensorflow/python/kernel_tests/functional_ops_test.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@
2323
import tensorflow as tf
2424

2525

26+
def simple_scoped_fn(a, x):
27+
"""Simple function: (a, x) -> 2(x+a), but with "2" as a variable in scope."""
28+
with tf.variable_scope("body"):
29+
# Dummy variable, just to check that scoping works as intended.
30+
two = tf.get_variable("two", [], dtype=tf.int32,
31+
initializer=tf.constant_initializer(2))
32+
return tf.mul(tf.add(a, x), two)
33+
34+
2635
class FunctionalOpsTest(tf.test.TestCase):
2736

2837
def testFoldl_Simple(self):
@@ -36,6 +45,24 @@ def testFoldl_Simple(self):
3645
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
3746
self.assertAllEqual(880, r.eval())
3847

48+
def testFoldl_Scoped(self):
49+
with self.test_session() as sess:
50+
with tf.variable_scope("root") as varscope:
51+
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
52+
53+
r = tf.foldl(simple_scoped_fn, elems)
54+
# Check that we have the one variable we asked for here.
55+
self.assertEqual(len(tf.trainable_variables()), 1)
56+
self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
57+
sess.run([tf.initialize_all_variables()])
58+
self.assertAllEqual(208, r.eval())
59+
60+
# Now let's reuse our single variable.
61+
varscope.reuse_variables()
62+
r = tf.foldl(simple_scoped_fn, elems, initializer=10)
63+
self.assertEqual(len(tf.trainable_variables()), 1)
64+
self.assertAllEqual(880, r.eval())
65+
3966
def testFoldr_Simple(self):
4067
with self.test_session():
4168
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
@@ -47,6 +74,24 @@ def testFoldr_Simple(self):
4774
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
4875
self.assertAllEqual(1282, r.eval())
4976

77+
def testFoldr_Scoped(self):
78+
with self.test_session() as sess:
79+
with tf.variable_scope("root") as varscope:
80+
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
81+
82+
r = tf.foldr(simple_scoped_fn, elems)
83+
# Check that we have the one variable we asked for here.
84+
self.assertEqual(len(tf.trainable_variables()), 1)
85+
self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
86+
sess.run([tf.initialize_all_variables()])
87+
self.assertAllEqual(450, r.eval())
88+
89+
# Now let's reuse our single variable.
90+
varscope.reuse_variables()
91+
r = tf.foldr(simple_scoped_fn, elems, initializer=10)
92+
self.assertEqual(len(tf.trainable_variables()), 1)
93+
self.assertAllEqual(1282, r.eval())
94+
5095
def testFold_Grad(self):
5196
with self.test_session():
5297
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
@@ -69,6 +114,34 @@ def testMap_Simple(self):
69114
r = tf.map_fn(lambda x: tf.mul(tf.add(x, 3), 2), elems)
70115
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
71116

117+
def testMap_Scoped(self):
118+
with self.test_session() as sess:
119+
120+
def double_scoped(x):
121+
"""2x with a dummy 2 that is scoped."""
122+
with tf.variable_scope("body"):
123+
# Dummy variable, just to check that scoping works as intended.
124+
two = tf.get_variable("two", [], dtype=tf.int32,
125+
initializer=tf.constant_initializer(2))
126+
return tf.mul(x, two)
127+
128+
with tf.variable_scope("root") as varscope:
129+
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
130+
doubles = np.array([2*x for x in [1, 2, 3, 4, 5, 6]])
131+
132+
r = tf.map_fn(double_scoped, elems)
133+
# Check that we have the one variable we asked for here.
134+
self.assertEqual(len(tf.trainable_variables()), 1)
135+
self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
136+
sess.run([tf.initialize_all_variables()])
137+
self.assertAllEqual(doubles, r.eval())
138+
139+
# Now let's reuse our single variable.
140+
varscope.reuse_variables()
141+
r = tf.map_fn(double_scoped, elems)
142+
self.assertEqual(len(tf.trainable_variables()), 1)
143+
self.assertAllEqual(doubles, r.eval())
144+
72145
def testMap_SimpleNotTensor(self):
73146
with self.test_session():
74147
nums = [1, 2, 3, 4, 5, 6]
@@ -87,6 +160,26 @@ def testScan_Simple(self):
87160
lambda a, x: tf.mul(a, x), elems, initializer=v)
88161
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
89162

163+
def testScan_Scoped(self):
164+
with self.test_session() as sess:
165+
with tf.variable_scope("root") as varscope:
166+
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
167+
168+
r = tf.scan(simple_scoped_fn, elems)
169+
# Check that we have the one variable we asked for here.
170+
self.assertEqual(len(tf.trainable_variables()), 1)
171+
self.assertEqual(tf.trainable_variables()[0].name, "root/body/two:0")
172+
sess.run([tf.initialize_all_variables()])
173+
results = np.array([1, 6, 18, 44, 98, 208])
174+
self.assertAllEqual(results, r.eval())
175+
176+
# Now let's reuse our single variable.
177+
varscope.reuse_variables()
178+
r = tf.scan(simple_scoped_fn, elems, initializer=2)
179+
self.assertEqual(len(tf.trainable_variables()), 1)
180+
results = np.array([6, 16, 38, 84, 178, 368])
181+
self.assertAllEqual(results, r.eval())
182+
90183
def testScan_Control(self):
91184
with self.test_session() as sess:
92185
s = tf.placeholder(tf.float32, shape=[None])

tensorflow/python/kernel_tests/rnn_cell_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,11 @@ def testEmbeddingWrapper(self):
251251
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
252252
x = tf.zeros([1, 1], dtype=tf.int32)
253253
m = tf.zeros([1, 2])
254-
g, new_m = tf.nn.rnn_cell.EmbeddingWrapper(
254+
embedding_cell = tf.nn.rnn_cell.EmbeddingWrapper(
255255
tf.nn.rnn_cell.GRUCell(2),
256-
embedding_classes=3, embedding_size=2)(x, m)
256+
embedding_classes=3, embedding_size=2)
257+
self.assertEqual(embedding_cell.output_size, 2)
258+
g, new_m = embedding_cell(x, m)
257259
sess.run([tf.initialize_all_variables()])
258260
res = sess.run([g, new_m], {x.name: np.array([[1]]),
259261
m.name: np.array([[0.1, 0.1]])})

tensorflow/python/ops/functional_ops.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,15 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
8787
if not callable(fn):
8888
raise TypeError("fn must be callable.")
8989

90-
# TODO(ebrevdo): Change to using colocate_with here and in other methods.
91-
with vs.variable_op_scope([elems], name, "foldl") as varscope:
92-
# Any get_variable calls fn will cache the first call locally
90+
with ops.op_scope([elems], name, "foldl"):
91+
# Any get_variable calls in fn will cache the first call locally
9392
# and not issue repeated network I/O requests for each iteration.
93+
varscope = vs.get_variable_scope()
94+
varscope_caching_device_was_none = False
9495
if varscope.caching_device is None:
96+
# TODO(ebrevdo): Change to using colocate_with here and in other methods.
9597
varscope.set_caching_device(lambda op: op.device)
98+
varscope_caching_device_was_none = True
9699

97100
# Convert elems to tensor array.
98101
elems = ops.convert_to_tensor(elems, name="elems")
@@ -117,6 +120,9 @@ def compute(i, a):
117120
parallel_iterations=parallel_iterations,
118121
back_prop=back_prop,
119122
swap_memory=swap_memory)
123+
124+
if varscope_caching_device_was_none:
125+
varscope.set_caching_device(None)
120126
return r_a
121127

122128

@@ -161,11 +167,15 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
161167
if not callable(fn):
162168
raise TypeError("fn must be callable.")
163169

164-
with vs.variable_op_scope([elems], name, "foldr") as varscope:
165-
# Any get_variable calls fn will cache the first call locally
170+
with ops.op_scope([elems], name, "foldr"):
171+
# Any get_variable calls in fn will cache the first call locally
166172
# and not issue repeated network I/O requests for each iteration.
173+
varscope = vs.get_variable_scope()
174+
varscope_caching_device_was_none = False
167175
if varscope.caching_device is None:
176+
# TODO(ebrevdo): Change to using colocate_with here and in other methods.
168177
varscope.set_caching_device(lambda op: op.device)
178+
varscope_caching_device_was_none = True
169179

170180
# Convert elems to tensor array.
171181
elems = ops.convert_to_tensor(elems, name="elems")
@@ -190,6 +200,9 @@ def compute(i, a):
190200
parallel_iterations=parallel_iterations,
191201
back_prop=back_prop,
192202
swap_memory=swap_memory)
203+
204+
if varscope_caching_device_was_none:
205+
varscope.set_caching_device(None)
193206
return r_a
194207

195208

@@ -232,11 +245,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
232245
if not callable(fn):
233246
raise TypeError("fn must be callable.")
234247

235-
with vs.variable_op_scope([elems], name, "map") as varscope:
236-
# Any get_variable calls fn will cache the first call locally
248+
with ops.op_scope([elems], name, "map"):
249+
# Any get_variable calls in fn will cache the first call locally
237250
# and not issue repeated network I/O requests for each iteration.
251+
varscope = vs.get_variable_scope()
252+
varscope_caching_device_was_none = False
238253
if varscope.caching_device is None:
254+
# TODO(ebrevdo): Change to using colocate_with here and in other methods.
239255
varscope.set_caching_device(lambda op: op.device)
256+
varscope_caching_device_was_none = True
240257

241258
elems = ops.convert_to_tensor(elems, name="elems")
242259
dtype = dtype if dtype else elems.dtype
@@ -263,6 +280,9 @@ def compute(i, ta):
263280
result = r_a.pack()
264281
result.set_shape(elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
265282
result.get_shape()[1:]))
283+
284+
if varscope_caching_device_was_none:
285+
varscope.set_caching_device(None)
266286
return result
267287

268288

@@ -307,11 +327,15 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
307327
if not callable(fn):
308328
raise TypeError("fn must be callable.")
309329

310-
with vs.variable_op_scope([elems], name, "scan") as varscope:
311-
# Any get_variable calls fn will cache the first call locally
330+
with ops.op_scope([elems], name, "scan"):
331+
# Any get_variable calls in fn will cache the first call locally
312332
# and not issue repeated network I/O requests for each iteration.
333+
varscope = vs.get_variable_scope()
334+
varscope_caching_device_was_none = False
313335
if varscope.caching_device is None:
336+
# TODO(ebrevdo): Change to using colocate_with here and in other methods.
314337
varscope.set_caching_device(lambda op: op.device)
338+
varscope_caching_device_was_none = True
315339

316340
# Convert elems to tensor array.
317341
elems = ops.convert_to_tensor(elems, name="elems")
@@ -346,6 +370,9 @@ def compute(i, a, ta):
346370
result = r_a.pack()
347371
result.set_shape(elems.get_shape().with_rank_at_least(1)[0:1].concatenate(
348372
result.get_shape()[1:]))
373+
374+
if varscope_caching_device_was_none:
375+
varscope.set_caching_device(None)
349376
return result
350377

351378

tensorflow/python/ops/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
145145
max_sequence_length = math_ops.reduce_max(sequence_length)
146146

147147
for time, input_ in enumerate(inputs):
148-
if time > 0: vs.get_variable_scope().reuse_variables()
148+
if time > 0: varscope.reuse_variables()
149149
# pylint: disable=cell-var-from-loop
150150
call_cell = lambda: cell(input_, state)
151151
# pylint: enable=cell-var-from-loop

tensorflow/python/ops/rnn_cell.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,10 @@ def __init__(self, cell, embedding_classes, embedding_size, initializer=None):
704704
def state_size(self):
705705
return self._cell.state_size
706706

707+
@property
708+
def output_size(self):
709+
return self._cell.output_size
710+
707711
def __call__(self, inputs, state, scope=None):
708712
"""Run the cell on embedded inputs."""
709713
with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"

0 commit comments

Comments
 (0)