Skip to content

Commit 793b92d

Browse files
authored
[BERT/PyT] fp32 and allreduce_post_accumulation compatibility (NVIDIA#422)
1 parent b03375b commit 793b92d

1 file changed

Lines changed: 15 additions & 10 deletions

File tree

PyTorch/LanguageModeling/BERT/run_pretraining.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -362,15 +362,14 @@ def prepare_model_and_optimizer(args, device):
362362
model = torch.nn.DataParallel(model)
363363

364364
return model, optimizer, lr_scheduler, checkpoint, global_step
365-
366365
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
367366

368367
global skipped_steps
369368
if args.allreduce_post_accumulation:
370369
# manually allreduce gradients after all accumulation steps
371370
# check for Inf/NaN
372371
# 1. allocate an uninitialized buffer for flattened gradient
373-
scaler = _amp_state.loss_scalers[0]
372+
loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
374373
master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
375374
flat_grad_size = sum(p.numel() for p in master_grads)
376375
allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
@@ -381,21 +380,24 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
381380
amp_C.multi_tensor_scale(65536,
382381
overflow_buf,
383382
[master_grads, allreduced_views],
384-
scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
383+
loss_scale / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
385384
# 3. sum gradient across ranks. Because of the predivision, this averages the gradient
386385
torch.distributed.all_reduce(flat_raw)
387386
# 4. combine unscaling and unflattening of allreduced gradient
388387
overflow_buf.zero_()
389388
amp_C.multi_tensor_scale(65536,
390389
overflow_buf,
391390
[allreduced_views, master_grads],
392-
1./scaler.loss_scale())
391+
1./loss_scale)
393392
# 5. update loss scale
394-
scaler = _amp_state.loss_scalers[0]
395-
old_overflow_buf = scaler._overflow_buf
396-
scaler._overflow_buf = overflow_buf
397-
had_overflow = scaler.update_scale()
398-
scaler._overfloat_buf = old_overflow_buf
393+
if args.fp16:
394+
scaler = _amp_state.loss_scalers[0]
395+
old_overflow_buf = scaler._overflow_buf
396+
scaler._overflow_buf = overflow_buf
397+
had_overflow = scaler.update_scale()
398+
scaler._overfloat_buf = old_overflow_buf
399+
else:
400+
had_overflow = 0
399401
# 6. call optimizer step function
400402
if had_overflow == 0:
401403
optimizer.step()
@@ -404,6 +406,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
404406
# Overflow detected, print message and clear gradients
405407
skipped_steps += 1
406408
if is_main_process():
409+
scaler = _amp_state.loss_scalers[0]
407410
dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
408411
if _amp_state.opt_properties.master_weights:
409412
for param in optimizer._amp_stash.all_fp32_from_fp16_params:
@@ -595,7 +598,9 @@ def main():
595598
now = time.time()
596599
args, final_loss, train_time_raw = main()
597600
gpu_count = args.n_gpu
598-
args.max_steps += args.phase1_end_step if args.phase2
601+
args.max_steps += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
602+
if args.resume_step == -1:
603+
args.resume_step = 0
599604
if torch.distributed.is_initialized():
600605
gpu_count = torch.distributed.get_world_size()
601606
if is_main_process():

0 commit comments

Comments
 (0)