114114
115115flags .DEFINE_bool ("use_xla" , False , "Whether to enable XLA JIT compilation." )
116116
117- flags .DEFINE_bool ("fastmath" , False , "Whether to enable loss scaler for fasthmath ops." )
118-
119117flags .DEFINE_bool ("amp" , False , "Whether to enable AMP ops." )
120118
121- flags .DEFINE_bool ("amp_fastmath" , False , "Whether to enable AMP fasthmath ops." )
122-
123119# report samples/sec, total loss and learning rate during training
124120class _LogSessionRunHook (tf .train .SessionRunHook ):
125121 def __init__ (self , global_batch_size , display_every = 10 , hvd_rank = - 1 ):
126122 self .global_batch_size = global_batch_size
127123 self .display_every = display_every
128124 self .hvd_rank = hvd_rank
129125 def after_create_session (self , session , coord ):
130- if FLAGS .use_fp16 or FLAGS .fastmath or FLAGS . amp or FLAGS . amp_fastmath :
126+ if FLAGS .use_fp16 or FLAGS .amp :
131127 print (' Step samples/sec MLM Loss NSP Loss Loss Learning-rate Loss-scaler' )
132128 else :
133129 print (' Step samples/sec MLM Loss NSP Loss Loss Learning-rate' )
134130 self .elapsed_secs = 0.
135131 self .count = 0
136132 def before_run (self , run_context ):
137133 self .t0 = time .time ()
138- if FLAGS .use_fp16 or FLAGS .fastmath or FLAGS . amp or FLAGS . amp_fastmath :
134+ if FLAGS .use_fp16 or FLAGS .amp :
139135 return tf .train .SessionRunArgs (
140136 fetches = ['step_update:0' , 'total_loss:0' ,
141137 'learning_rate:0' , 'nsp_loss:0' ,
@@ -148,7 +144,7 @@ def before_run(self, run_context):
148144 def after_run (self , run_context , run_values ):
149145 self .elapsed_secs += time .time () - self .t0
150146 self .count += 1
151- if FLAGS .use_fp16 or FLAGS .fastmath or FLAGS . amp or FLAGS . amp_fastmath :
147+ if FLAGS .use_fp16 or FLAGS .amp :
152148 global_step , total_loss , lr , nsp_loss , mlm_loss , loss_scaler = run_values .results
153149 else :
154150 global_step , total_loss , lr , nsp_loss , mlm_loss = run_values .results
@@ -157,14 +153,14 @@ def after_run(self, run_context, run_values):
157153 dt = self .elapsed_secs / self .count
158154 img_per_sec = self .global_batch_size / dt
159155 if self .hvd_rank >= 0 :
160- if FLAGS .use_fp16 or FLAGS .fastmath or FLAGS . amp or FLAGS . amp_fastmath :
156+ if FLAGS .use_fp16 or FLAGS .amp :
161157 print ('%2d :: %6i %11.1f %10.4e %10.4e %6.3f %6.4e %6.4e' %
162158 (self .hvd_rank , print_step , img_per_sec , mlm_loss , nsp_loss , total_loss , lr , loss_scaler ))
163159 else :
164160 print ('%2d :: %6i %11.1f %10.4e %10.4e %6.3f %6.4e' %
165161 (self .hvd_rank , print_step , img_per_sec , mlm_loss , nsp_loss , total_loss , lr ))
166162 else :
167- if FLAGS .use_fp16 or FLAGS .fastmath or FLAGS . amp or FLAGS . amp_fastmath :
163+ if FLAGS .use_fp16 or FLAGS .amp :
168164 print ('%6i %11.1f %10.4e %10.4e %6.3f %6.4e %6.4e' %
169165 (print_step , img_per_sec , mlm_loss , nsp_loss , total_loss , lr , loss_scaler ))
170166 else :
@@ -247,7 +243,7 @@ def tpu_scaffold():
247243 if mode == tf .estimator .ModeKeys .TRAIN :
248244 train_op = optimization .create_optimizer (
249245 total_loss , learning_rate , num_train_steps , num_warmup_steps , use_tpu ,
250- hvd , FLAGS .use_fp16 , FLAGS .fastmath , FLAGS . amp , FLAGS . amp_fastmath )
246+ hvd , FLAGS .use_fp16 , FLAGS .amp )
251247
252248 output_spec = tf .contrib .tpu .TPUEstimatorSpec (
253249 mode = mode ,
@@ -483,24 +479,8 @@ def main(_):
483479 if not FLAGS .do_train and not FLAGS .do_eval :
484480 raise ValueError ("At least one of `do_train` or `do_eval` must be True." )
485481
486- if FLAGS .fastmath and FLAGS .amp :
487- raise ValueError ("Only one of fastmath or amp must be True." )
488-
489- if FLAGS .fastmath and FLAGS .amp_fastmath :
490- raise ValueError ("Only one of fastmath or amp_fastmath must be True." )
491-
492- if FLAGS .amp and FLAGS .amp_fastmath :
493- raise ValueError ("Only one of amp or amp_fastmath must be True." )
494-
495- if FLAGS .fastmath :
496- os .environ ["TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32" ] = "1"
497- os .environ ["TF_ENABLE_CUDNN_TENSOR_OP_MATH_FP32" ] = "1"
498- os .environ ["TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32" ] = "1"
499- elif FLAGS .amp :
500- os .environ ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE" ] = "1"
501- elif FLAGS .amp_fastmath :
482+ if FLAGS .amp :
502483 os .environ ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE" ] = "1"
503- os .environ ["TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL" ] = "TENSOR_CORES_ONLY"
504484
505485 if FLAGS .horovod :
506486 import horovod .tensorflow as hvd
0 commit comments