Skip to content

Commit 9df464f

Browse files
authored
[BERT/PyT] stop and resume, single gpu and timing fixes. (NVIDIA#509)
* stop and resume, single gpu and timing fixes.
* Update utils.py
* accumulation features check
1 parent 3aae020 commit 9df464f

2 files changed

Lines changed: 12 additions & 1 deletion

File tree

PyTorch/LanguageModeling/BERT/run_pretraining.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,10 @@ def setup_training(args):
291291
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
292292
torch.distributed.init_process_group(backend='nccl', init_method='env://')
293293
args.n_gpu = 1
294+
295+
if args.gradient_accumulation_steps == 1:
296+
args.allreduce_post_accumulation = False
297+
args.allreduce_post_accumulation_fp16 = False
294298

295299
if is_main_process():
296300
dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,

PyTorch/LanguageModeling/BERT/utils.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,13 @@ def get_rank():
2121
return 0
2222
return dist.get_rank()
2323

24+
def get_world_size():
25+
if not dist.is_available():
26+
return 1
27+
if not dist.is_initialized():
28+
return 1
29+
return dist.get_world_size()
30+
2431
def is_main_process():
2532
return get_rank() == 0
2633

@@ -34,4 +41,4 @@ def format_step(step):
3441
s += "Training Iteration: {} ".format(step[1])
3542
if len(step) > 2:
3643
s += "Validation Iteration: {} ".format(step[2])
37-
return s
44+
return s

0 commit comments

Comments
 (0)