From feb3af3886d06c8ffd12d52441145115ebf6adfd Mon Sep 17 00:00:00 2001 From: Sharath T S Date: Sun, 15 Mar 2020 17:25:31 -0700 Subject: [PATCH] fp32 and allreduce_post_accumulation compatibility --- .../LanguageModeling/BERT/run_pretraining.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/PyTorch/LanguageModeling/BERT/run_pretraining.py b/PyTorch/LanguageModeling/BERT/run_pretraining.py index cc0c58271..818cb5ecb 100755 --- a/PyTorch/LanguageModeling/BERT/run_pretraining.py +++ b/PyTorch/LanguageModeling/BERT/run_pretraining.py @@ -362,7 +362,6 @@ def prepare_model_and_optimizer(args, device): model = torch.nn.DataParallel(model) return model, optimizer, lr_scheduler, checkpoint, global_step - def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): global skipped_steps @@ -370,7 +369,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): # manually allreduce gradients after all accumulation steps # check for Inf/NaN # 1. allocate an uninitialized buffer for flattened gradient - scaler = _amp_state.loss_scalers[0] + loss_scale = _amp_state.loss_scalers[0].loss_scale() if args.fp16 else 1 master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None] flat_grad_size = sum(p.numel() for p in master_grads) allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32 @@ -381,7 +380,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): amp_C.multi_tensor_scale(65536, overflow_buf, [master_grads, allreduced_views], - scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps)) + loss_scale / (torch.distributed.get_world_size() * args.gradient_accumulation_steps)) # 3. sum gradient across ranks. Because of the predivision, this averages the gradient torch.distributed.all_reduce(flat_raw) # 4. combine unscaling and unflattening of allreduced gradient @@ -389,13 +388,16 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): amp_C.multi_tensor_scale(65536, overflow_buf, [allreduced_views, master_grads], - 1./scaler.loss_scale()) + 1./loss_scale) # 5. update loss scale - scaler = _amp_state.loss_scalers[0] - old_overflow_buf = scaler._overflow_buf - scaler._overflow_buf = overflow_buf - had_overflow = scaler.update_scale() - scaler._overfloat_buf = old_overflow_buf + if args.fp16: + scaler = _amp_state.loss_scalers[0] + old_overflow_buf = scaler._overflow_buf + scaler._overflow_buf = overflow_buf + had_overflow = scaler.update_scale() + scaler._overfloat_buf = old_overflow_buf + else: + had_overflow = 0 # 6. call optimizer step function if had_overflow == 0: optimizer.step() @@ -404,6 +406,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): # Overflow detected, print message and clear gradients skipped_steps += 1 if is_main_process(): + scaler = _amp_state.loss_scalers[0] dllogger.log(step="PARAMETER", data={"loss_scale": scaler.loss_scale()}) if _amp_state.opt_properties.master_weights: for param in optimizer._amp_stash.all_fp32_from_fp16_params: @@ -595,7 +598,9 @@ def main(): now = time.time() args, final_loss, train_time_raw = main() gpu_count = args.n_gpu - args.max_steps += args.phase1_end_step if args.phase2 + args.max_steps += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0 + if args.resume_step == -1: + args.resume_step = 0 if torch.distributed.is_initialized(): gpu_count = torch.distributed.get_world_size() if is_main_process():