File tree Expand file tree Collapse file tree
PyTorch/LanguageModeling/BERT Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -291,6 +291,10 @@ def setup_training(args):
291291 # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
292292 torch .distributed .init_process_group (backend = 'nccl' , init_method = 'env://' )
293293 args .n_gpu = 1
294+
295+ if args .gradient_accumulation_steps == 1 :
296+ args .allreduce_post_accumulation = False
297+ args .allreduce_post_accumulation_fp16 = False
294298
295299 if is_main_process ():
296300 dllogger .init (backends = [dllogger .JSONStreamBackend (verbosity = dllogger .Verbosity .VERBOSE ,
Original file line number Diff line number Diff line change @@ -21,6 +21,13 @@ def get_rank():
2121 return 0
2222 return dist .get_rank ()
2323
def get_world_size() -> int:
    """Return the world size reported by ``dist``.

    Falls back to 1 when ``dist`` reports that it is unavailable or that
    it has not been initialized, so callers can use the result without
    first checking the distributed state themselves.
    """
    # Both guard conditions yield the same fallback; the short-circuit
    # keeps is_initialized() from being called when dist is unavailable.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
30+
def is_main_process() -> bool:
    """Return True when ``get_rank()`` reports rank 0 for this process."""
    return get_rank() == 0
2633
@@ -34,4 +41,4 @@ def format_step(step):
3441 s += "Training Iteration: {} " .format (step [1 ])
3542 if len (step ) > 2 :
3643 s += "Validation Iteration: {} " .format (step [2 ])
37- return s
44+ return s
You can’t perform that action at this time.
0 commit comments