Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions PyTorch/LanguageModeling/BERT/run_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,14 @@ def prepare_model_and_optimizer(args, device):
model = torch.nn.DataParallel(model)

return model, optimizer, lr_scheduler, checkpoint, global_step

def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):

global skipped_steps
if args.allreduce_post_accumulation:
# manually allreduce gradients after all accumulation steps
# check for Inf/NaN
# 1. allocate an uninitialized buffer for flattened gradient
scaler = _amp_state.loss_scalers[0]
loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1
master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
flat_grad_size = sum(p.numel() for p in master_grads)
allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
Expand All @@ -381,21 +380,24 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
amp_C.multi_tensor_scale(65536,
overflow_buf,
[master_grads, allreduced_views],
scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
loss_scale / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
# 3. sum gradient across ranks. Because of the predivision, this averages the gradient
torch.distributed.all_reduce(flat_raw)
# 4. combine unscaling and unflattening of allreduced gradient
overflow_buf.zero_()
amp_C.multi_tensor_scale(65536,
overflow_buf,
[allreduced_views, master_grads],
1./scaler.loss_scale())
1./loss_scale)
# 5. update loss scale
scaler = _amp_state.loss_scalers[0]
old_overflow_buf = scaler._overflow_buf
scaler._overflow_buf = overflow_buf
had_overflow = scaler.update_scale()
scaler._overfloat_buf = old_overflow_buf
if args.fp16:
scaler = _amp_state.loss_scalers[0]
old_overflow_buf = scaler._overflow_buf
scaler._overflow_buf = overflow_buf
had_overflow = scaler.update_scale()
scaler._overfloat_buf = old_overflow_buf
else:
had_overflow = 0
# 6. call optimizer step function
if had_overflow == 0:
optimizer.step()
Expand All @@ -404,6 +406,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
# Overflow detected, print message and clear gradients
skipped_steps += 1
if is_main_process():
scaler = _amp_state.loss_scalers[0]
dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()})
if _amp_state.opt_properties.master_weights:
for param in optimizer._amp_stash.all_fp32_from_fp16_params:
Expand Down Expand Up @@ -595,7 +598,9 @@ def main():
now = time.time()
args, final_loss, train_time_raw = main()
gpu_count = args.n_gpu
args.max_steps += args.phase1_end_step if args.phase2
args.max_steps += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
if args.resume_step == -1:
args.resume_step = 0
if torch.distributed.is_initialized():
gpu_count = torch.distributed.get_world_size()
if is_main_process():
Expand Down