@@ -362,15 +362,14 @@ def prepare_model_and_optimizer(args, device):
362362 model = torch .nn .DataParallel (model )
363363
364364 return model , optimizer , lr_scheduler , checkpoint , global_step
365-
366365def take_optimizer_step (args , optimizer , model , overflow_buf , global_step ):
367366
368367 global skipped_steps
369368 if args .allreduce_post_accumulation :
370369 # manually allreduce gradients after all accumulation steps
371370 # check for Inf/NaN
372371 # 1. allocate an uninitialized buffer for flattened gradient
373- scaler = _amp_state .loss_scalers [0 ]
372+ loss_scale = _amp_state .loss_scalers [0 ]. loss_scale () if args . fp16 else 1
374373 master_grads = [p .grad for p in amp .master_params (optimizer ) if p .grad is not None ]
375374 flat_grad_size = sum (p .numel () for p in master_grads )
376375 allreduce_dtype = torch .float16 if args .allreduce_post_accumulation_fp16 else torch .float32
@@ -381,21 +380,24 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
381380 amp_C .multi_tensor_scale (65536 ,
382381 overflow_buf ,
383382 [master_grads , allreduced_views ],
384- scaler . loss_scale () / (torch .distributed .get_world_size () * args .gradient_accumulation_steps ))
383+ loss_scale / (torch .distributed .get_world_size () * args .gradient_accumulation_steps ))
385384 # 3. sum gradient across ranks. Because of the predivision, this averages the gradient
386385 torch .distributed .all_reduce (flat_raw )
387386 # 4. combine unscaling and unflattening of allreduced gradient
388387 overflow_buf .zero_ ()
389388 amp_C .multi_tensor_scale (65536 ,
390389 overflow_buf ,
391390 [allreduced_views , master_grads ],
392- 1. / scaler . loss_scale () )
391+ 1. / loss_scale )
393392 # 5. update loss scale
394- scaler = _amp_state .loss_scalers [0 ]
395- old_overflow_buf = scaler ._overflow_buf
396- scaler ._overflow_buf = overflow_buf
397- had_overflow = scaler .update_scale ()
398- scaler ._overfloat_buf = old_overflow_buf
393+ if args .fp16 :
394+ scaler = _amp_state .loss_scalers [0 ]
395+ old_overflow_buf = scaler ._overflow_buf
396+ scaler ._overflow_buf = overflow_buf
397+ had_overflow = scaler .update_scale ()
398+ scaler ._overfloat_buf = old_overflow_buf
399+ else :
400+ had_overflow = 0
399401 # 6. call optimizer step function
400402 if had_overflow == 0 :
401403 optimizer .step ()
@@ -404,6 +406,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
404406 # Overflow detected, print message and clear gradients
405407 skipped_steps += 1
406408 if is_main_process ():
409+ scaler = _amp_state .loss_scalers [0 ]
407410 dllogger .log (step = "PARAMETER" , data = {"loss_scale" : scaler .loss_scale ()})
408411 if _amp_state .opt_properties .master_weights :
409412 for param in optimizer ._amp_stash .all_fp32_from_fp16_params :
@@ -595,7 +598,9 @@ def main():
595598 now = time .time ()
596599 args , final_loss , train_time_raw = main ()
597600 gpu_count = args .n_gpu
598- args .max_steps += args .phase1_end_step if args .phase2
601+ args .max_steps += args .phase1_end_step if (args .phase2 and args .resume_step > 0 ) else 0
602+ if args .resume_step == - 1 :
603+ args .resume_step = 0
599604 if torch .distributed .is_initialized ():
600605 gpu_count = torch .distributed .get_world_size ()
601606 if is_main_process ():
0 commit comments