@@ -123,6 +123,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
123123 while not mon_sess.should_stop():
124124 mon_sess.run(training_op)
125125 ```
126+
127+ To use SyncReplicasOptimizer with an `Estimator`, you need to send
128+ sync_replicas_hook while calling the fit.
129+ ```
130+ my_estimator = DNNClassifier(..., optimizer=opt)
131+ my_estimator.fit(..., hooks=[sync_replicas_hook])
132+ ```
126133 """
127134
128135 def __init__ (self ,
@@ -418,34 +425,42 @@ def get_init_tokens_op(self, num_tokens=-1):
418425
419426 def make_session_run_hook (self , is_chief , num_tokens = - 1 ):
420427 """Creates a hook to handle SyncReplicasHook ops such as initialization."""
421- if is_chief :
422- return _SyncReplicasOptimizerHook (self .chief_init_op ,
423- self .ready_for_local_init_op ,
424- self .get_chief_queue_runner (),
425- self .get_init_tokens_op (num_tokens ))
426-
427- return _SyncReplicasOptimizerHook (self .local_step_init_op ,
428- self .ready_for_local_init_op , None , None )
428+ return _SyncReplicasOptimizerHook (self , is_chief , num_tokens )
429429
430430
431431class _SyncReplicasOptimizerHook (session_run_hook .SessionRunHook ):
432432 """A SessionRunHook handles ops related to SyncReplicasOptimizer."""
433433
434- def __init__ (self , local_init_op , ready_for_local_init_op , q_runner ,
435- init_tokens_op ):
434+ def __init__ (self , sync_optimizer , is_chief , num_tokens ):
436435 """Creates hook to handle SyncReplicaOptimizer initialization ops.
437436
438437 Args:
439- local_init_op: Either `SyncReplicasOptimizer.chief_init_op` or
440- `SyncReplicasOptimizer.local_step_init_op`.
441- ready_for_local_init_op: `SyncReplicasOptimizer.ready_for_local_init_op`
442- q_runner: Either `SyncReplicasOptimizer.get_chief_queue_runner` or `None`
443- init_tokens_op: `SyncReplicasOptimizer.get_init_tokens_op` or None
438+ sync_optimizer: `SyncReplicasOptimizer` which this hook will initialize.
439+ is_chief: `Bool`, whether is this a chief replica or not.
440+ num_tokens: Number of tokens to add to the queue.
444441 """
445- self ._local_init_op = local_init_op
446- self ._ready_for_local_init_op = ready_for_local_init_op
447- self ._q_runner = q_runner
448- self ._init_tokens_op = init_tokens_op
442+ self ._sync_optimizer = sync_optimizer
443+ self ._is_chief = is_chief
444+ self ._num_tokens = num_tokens
445+
446+ def begin (self ):
447+ if self ._sync_optimizer ._gradients_applied is False : # pylint: disable=protected-access
448+ raise ValueError (
449+ "SyncReplicasOptimizer.apply_gradient should be called before using "
450+ "the hook." )
451+ if self ._is_chief :
452+ self ._local_init_op = self ._sync_optimizer .chief_init_op
453+ self ._ready_for_local_init_op = (
454+ self ._sync_optimizer .ready_for_local_init_op )
455+ self ._q_runner = self ._sync_optimizer .get_chief_queue_runner ()
456+ self ._init_tokens_op = self ._sync_optimizer .get_init_tokens_op (
457+ self ._num_tokens )
458+ else :
459+ self ._local_init_op = self ._sync_optimizer .local_step_init_op
460+ self ._ready_for_local_init_op = (
461+ self ._sync_optimizer .ready_for_local_init_op )
462+ self ._q_runner = None
463+ self ._init_tokens_op = None
449464
450465 def after_create_session (self , session , coord ):
451466 """Runs SyncReplicasOptimizer initialization ops."""
0 commit comments