Skip to content

Commit 4903865

Browse files
ispirmustafatensorflower-gardener
authored andcommitted
Integrate SyncReplicasOptimizer with Estimators.
There is an hidden dependency between when 'apply_gradient' and get_chief_queue_runner() are called. This cl postpones creation of the queue to the initialization of the Session. In Estimator, Session is created after forming the graph/training-op. That means it is after the apply_gradient. Change: 147746938
1 parent 6640d3f commit 4903865

2 files changed

Lines changed: 59 additions & 19 deletions

File tree

tensorflow/python/training/sync_replicas_optimizer.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
123123
while not mon_sess.should_stop():
124124
mon_sess.run(training_op)
125125
```
126+
127+
To use SyncReplicasOptimizer with an `Estimator`, you need to send
128+
sync_replicas_hook while calling the fit.
129+
```
130+
my_estimator = DNNClassifier(..., optimizer=opt)
131+
my_estimator.fit(..., hooks=[sync_replicas_hook])
132+
```
126133
"""
127134

128135
def __init__(self,
@@ -418,34 +425,42 @@ def get_init_tokens_op(self, num_tokens=-1):
418425

419426
def make_session_run_hook(self, is_chief, num_tokens=-1):
420427
"""Creates a hook to handle SyncReplicasHook ops such as initialization."""
421-
if is_chief:
422-
return _SyncReplicasOptimizerHook(self.chief_init_op,
423-
self.ready_for_local_init_op,
424-
self.get_chief_queue_runner(),
425-
self.get_init_tokens_op(num_tokens))
426-
427-
return _SyncReplicasOptimizerHook(self.local_step_init_op,
428-
self.ready_for_local_init_op, None, None)
428+
return _SyncReplicasOptimizerHook(self, is_chief, num_tokens)
429429

430430

431431
class _SyncReplicasOptimizerHook(session_run_hook.SessionRunHook):
432432
"""A SessionRunHook handles ops related to SyncReplicasOptimizer."""
433433

434-
def __init__(self, local_init_op, ready_for_local_init_op, q_runner,
435-
init_tokens_op):
434+
def __init__(self, sync_optimizer, is_chief, num_tokens):
436435
"""Creates hook to handle SyncReplicaOptimizer initialization ops.
437436
438437
Args:
439-
local_init_op: Either `SyncReplicasOptimizer.chief_init_op` or
440-
`SyncReplicasOptimizer.local_step_init_op`.
441-
ready_for_local_init_op: `SyncReplicasOptimizer.ready_for_local_init_op`
442-
q_runner: Either `SyncReplicasOptimizer.get_chief_queue_runner` or `None`
443-
init_tokens_op: `SyncReplicasOptimizer.get_init_tokens_op` or None
438+
sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
439+
is_chief: `Bool`, whether is this a chief replica or not.
440+
num_tokens: Number of tokens to add to the queue.
444441
"""
445-
self._local_init_op = local_init_op
446-
self._ready_for_local_init_op = ready_for_local_init_op
447-
self._q_runner = q_runner
448-
self._init_tokens_op = init_tokens_op
442+
self._sync_optimizer = sync_optimizer
443+
self._is_chief = is_chief
444+
self._num_tokens = num_tokens
445+
446+
def begin(self):
447+
if self._sync_optimizer._gradients_applied is False: # pylint: disable=protected-access
448+
raise ValueError(
449+
"SyncReplicasOptimizer.apply_gradient should be called before using "
450+
"the hook.")
451+
if self._is_chief:
452+
self._local_init_op = self._sync_optimizer.chief_init_op
453+
self._ready_for_local_init_op = (
454+
self._sync_optimizer.ready_for_local_init_op)
455+
self._q_runner = self._sync_optimizer.get_chief_queue_runner()
456+
self._init_tokens_op = self._sync_optimizer.get_init_tokens_op(
457+
self._num_tokens)
458+
else:
459+
self._local_init_op = self._sync_optimizer.local_step_init_op
460+
self._ready_for_local_init_op = (
461+
self._sync_optimizer.ready_for_local_init_op)
462+
self._q_runner = None
463+
self._init_tokens_op = None
449464

450465
def after_create_session(self, session, coord):
451466
"""Runs SyncReplicasOptimizer initialization ops."""

tensorflow/python/training/sync_replicas_optimizer_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,5 +277,30 @@ def test3Workers1Backup(self):
277277
sessions[1].run(var_1_g_1))
278278

279279

280+
class SyncReplicasOptimizerHookTest(test.TestCase):
281+
282+
def testErrorIfUsedBeforeMinimizeCalled(self):
283+
opt = training.SyncReplicasOptimizer(
284+
opt=gradient_descent.GradientDescentOptimizer(1.0),
285+
replicas_to_aggregate=1,
286+
total_num_replicas=1)
287+
hook = opt.make_session_run_hook(True)
288+
with self.assertRaisesRegexp(ValueError,
289+
"apply_gradient should be called"):
290+
hook.begin()
291+
292+
def testCanCreatedBeforeMinimizeCalled(self):
293+
"""This behavior is required to be integrated with Estimators."""
294+
opt = training.SyncReplicasOptimizer(
295+
opt=gradient_descent.GradientDescentOptimizer(1.0),
296+
replicas_to_aggregate=1,
297+
total_num_replicas=1)
298+
hook = opt.make_session_run_hook(True)
299+
v = variables.Variable([0.])
300+
global_step = variables.Variable(0, name="global_step", trainable=False)
301+
opt.minimize(v, global_step=global_step)
302+
hook.begin()
303+
304+
280305
if __name__ == "__main__":
281306
test.main()

0 commit comments

Comments
 (0)