Skip to content

Commit 2d55554

Browse files
authored
[ResNet/TF] Fix gradient calculation for sync variable
1 parent 0120131 commit 2d55554

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

  • TensorFlow/Classification/ConvNets/model

TensorFlow/Classification/ConvNets/model/resnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ def __call__(self, features, labels, mode, params):
239239

240240
with tf.device("/cpu:0"):
241241
if hvd_utils.is_using_hvd():
242-
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var")
242+
sync_var = tf.Variable(initial_value=[0], dtype=tf.int32, name="signal_handler_var",
243+
trainable=False)
243244
sync_var_assing = sync_var.assign([1], name="signal_handler_var_set")
244245
sync_var_reset = sync_var.assign([0], name="signal_handler_var_reset")
245246
sync_op = hvd.allreduce(sync_var, op=hvd.Sum, name="signal_handler_all_reduce")

0 commit comments

Comments
 (0)