|
28 | 28 | from horovod.tensorflow.compression import Compression |
29 | 29 |
|
30 | 30 | def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None, manual_fp16=False, use_fp16=False, num_accumulation_steps=1, |
31 | | - optimizer_type="adam", allreduce_post_accumulation=False): |
| 31 | + optimizer_type="adam", allreduce_post_accumulation=False, init_loss_scale=2**32): |
32 | 32 | """Creates an optimizer training op.""" |
33 | 33 | global_step = tf.compat.v1.train.get_or_create_global_step() |
34 | 34 |
|
@@ -96,11 +96,11 @@ def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, hvd=None, |
96 | 96 | if hvd is not None and (num_accumulation_steps == 1 or (not allreduce_post_accumulation)): |
97 | 97 | optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True, compression=Compression.fp16 if use_fp16 or manual_fp16 else Compression.none) |
98 | 98 | if use_fp16: |
99 | | - loss_scaler = tf.train.experimental.DynamicLossScale(initial_loss_scale=2**32, increment_period=1000, multiplier=2.0) |
| 99 | + loss_scaler = tf.train.experimental.DynamicLossScale(initial_loss_scale=init_loss_scale, increment_period=1000, multiplier=2.0) |
100 | 100 | optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler) |
101 | 101 | loss_scale_value = tf.identity(loss_scaler(), name="loss_scale") |
102 | 102 | if manual_fp16: |
103 | | - loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=2 ** 32, |
| 103 | + loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(init_loss_scale=init_loss_scale, |
104 | 104 | incr_every_n_steps=1000, |
105 | 105 | decr_every_n_nan_or_inf=2, |
106 | 106 | decr_ratio=0.5) |
@@ -157,7 +157,7 @@ def update(accum_vars): |
157 | 157 | lambda: update(accum_vars), lambda: tf.no_op()) |
158 | 158 |
|
159 | 159 | new_global_step = tf.cond(tf.math.logical_and(update_step, |
160 | | - tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool)) if hvd is not None else batch_finite, |
| 160 | + tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool) if hvd is not None else batch_finite), |
161 | 161 | lambda: global_step+1, |
162 | 162 | lambda: global_step) |
163 | 163 | new_global_step = tf.identity(new_global_step, name='step_update') |
|
0 commit comments