Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 69 additions & 33 deletions PyTorch/LanguageModeling/BERT/run_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from schedulers import PolyWarmUpScheduler

from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from utils import is_main_process, format_step
from utils import is_main_process, format_step, get_world_size, get_rank
from apex.parallel import DistributedDataParallel as DDP
from schedulers import LinearWarmUpScheduler
from apex.parallel.distributed import flat_dist_call
Expand All @@ -59,6 +59,18 @@

skipped_steps = 0

# Set to True once a SIGTERM (cluster time limit reached) has been received.
timeout_sent = False

import signal


def signal_handler(sig, frame):
    """Record that a SIGTERM arrived so the run can save and exit gracefully.

    The handler only flips the module-level ``timeout_sent`` flag; it does no
    other work itself.

    Args:
        sig: Signal number supplied by the ``signal`` module (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    global timeout_sent
    timeout_sent = True


# Handle the SIGTERM sent by the scheduler by marking the flag above.
signal.signal(signal.SIGTERM, signal_handler)

#Workaround because python functions are not picklable
class WorkerInitObj(object):
def __init__(self, seed):
Expand Down Expand Up @@ -107,6 +119,7 @@ def __getitem__(self, index):

return [input_ids, segment_ids, input_mask,
masked_lm_labels, next_sentence_labels]

class BertPretrainingCriterion(torch.nn.Module):
def __init__(self, vocab_size):
super(BertPretrainingCriterion, self).__init__()
Expand Down Expand Up @@ -241,6 +254,10 @@ def parse_arguments():
type=int,
default=7038,
help="Number of training steps in Phase1 - seq len 128")
parser.add_argument('--init_loss_scale',
type=int,
default=2**20,
help="Initial loss scaler value")
parser.add_argument("--do_train",
default=False,
action='store_true',
Expand All @@ -266,12 +283,18 @@ def setup_training(args):
if args.local_rank == -1:
device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
args.allreduce_post_accumulation = False
args.allreduce_post_accumulation_fp16 = False
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.n_gpu = 1

if args.gradient_accumulation_steps == 1:
args.allreduce_post_accumulation = False
args.allreduce_post_accumulation_fp16 = False

if is_main_process():
dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
Expand Down Expand Up @@ -357,7 +380,7 @@ def prepare_model_and_optimizer(args, device):
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", cast_model_outputs=torch.float16)
else:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, cast_model_outputs=torch.float16)
amp._amp_state.loss_scalers[0]._loss_scale = 2**20
amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

model.checkpoint_activations(args.checkpoint_activations)

Expand All @@ -384,7 +407,7 @@ def prepare_model_and_optimizer(args, device):

if args.local_rank != -1:
if not args.allreduce_post_accumulation:
model = DDP(model, message_size=250000000, gradient_predivide_factor=torch.distributed.get_world_size())
model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
else:
flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
elif args.n_gpu > 1:
Expand Down Expand Up @@ -412,7 +435,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
amp_C.multi_tensor_scale(65536,
overflow_buf,
[master_grads, allreduced_views],
loss_scale / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
loss_scale / (get_world_size() * args.gradient_accumulation_steps))
# 3. sum gradient across ranks. Because of the predivision, this averages the gradient
torch.distributed.all_reduce(flat_raw)
# 4. combine unscaling and unflattening of allreduced gradient
Expand Down Expand Up @@ -455,6 +478,7 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
return global_step

def main():
global timeout_sent

args = parse_arguments()

Expand All @@ -476,7 +500,7 @@ def main():
if is_main_process():
dllogger.log(step="PARAMETER", data={"SEED": args.seed})

raw_train_start = time.time()
raw_train_start = None
if args.do_train:
if is_main_process():
dllogger.log(step="PARAMETER", data={"train_start": True})
Expand All @@ -494,57 +518,65 @@ def main():
# Note: We loop infinitely over epochs, termination is handled via iteration count
while True:
thread = None
restored_data_loader = None
if not args.resume_from_checkpoint or epoch > 0 or (args.phase2 and global_step < 1) or args.init_checkpoint:
files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f]
files.sort()
num_files = len(files)
random.shuffle(files)
random.Random(args.seed + epoch).shuffle(files)
f_start_id = 0
else:
f_start_id = checkpoint['files'][0]
files = checkpoint['files'][1:]
args.resume_from_checkpoint = False
num_files = len(files)

# 'epoch' and 'data_loader' may not exist in all checkpoints, so fall back
# to defaults when they are missing.
epoch = checkpoint.get('epoch', 0)
# Bind `restored_data_loader` (with the underscore): that is the name reset at
# the top of the epoch loop and read when deciding whether to build a new
# DataLoader. Assigning to `restored_dataloader` silently dropped the
# restored loader.
restored_data_loader = checkpoint.get('data_loader', None)

shared_file_list = {}

if torch.distributed.is_initialized() and torch.distributed.get_world_size() > num_files:
remainder = torch.distributed.get_world_size() % num_files
data_file = files[(f_start_id*torch.distributed.get_world_size()+torch.distributed.get_rank() + remainder*f_start_id)%num_files]
if torch.distributed.is_initialized() and get_world_size() > num_files:
remainder = get_world_size() % num_files
data_file = files[(f_start_id*get_world_size()+get_rank() + remainder*f_start_id)%num_files]
else:
data_file = files[(f_start_id*torch.distributed.get_world_size()+torch.distributed.get_rank())%num_files]
data_file = files[(f_start_id*get_world_size()+get_rank())%num_files]

previous_file = data_file

train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size * args.n_gpu,
num_workers=4, worker_init_fn=worker_init,
pin_memory=True)
# shared_file_list["0"] = (train_dataloader, data_file)
if restored_data_loader is None:
train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size * args.n_gpu,
num_workers=4, worker_init_fn=worker_init,
pin_memory=True)
# shared_file_list["0"] = (train_dataloader, data_file)
else:
train_dataloader = restored_data_loader
restored_data_loader = None

overflow_buf = None
if args.allreduce_post_accumulation:
overflow_buf = torch.cuda.IntTensor([0])

if len(files) == 1:
f_start_id = -1

for f_id in range(f_start_id + 1 , len(files)):


if torch.distributed.get_world_size() > num_files:
data_file = files[(f_id*torch.distributed.get_world_size()+torch.distributed.get_rank() + remainder*f_id)%num_files]
if get_world_size() > num_files:
data_file = files[(f_id*get_world_size()+get_rank() + remainder*f_id)%num_files]
else:
data_file = files[(f_id*torch.distributed.get_world_size()+torch.distributed.get_rank())%num_files]
data_file = files[(f_id*get_world_size()+get_rank())%num_files]

previous_file = data_file

dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init)

train_iter = tqdm(train_dataloader, desc="Iteration", disable=args.disable_progress_bar) if is_main_process() else train_dataloader

if raw_train_start is None:
raw_train_start = time.time()
for step, batch in enumerate(train_iter):

training_steps += 1
Expand Down Expand Up @@ -579,7 +611,7 @@ def main():
average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
average_loss = average_loss / (last_num_steps * divisor)
if (torch.distributed.is_initialized()):
average_loss /= torch.distributed.get_world_size()
average_loss /= get_world_size()
torch.distributed.all_reduce(average_loss)
final_loss = average_loss.item()
if is_main_process():
Expand All @@ -592,7 +624,7 @@ def main():
average_loss = 0

if global_step >= args.max_steps or training_steps % (
args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0:
args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 or timeout_sent:
if is_main_process() and not args.skip_checkpoint:
# Save a trained model
dllogger.log(step="PARAMETER", data={"checkpoint_step": global_step})
Expand All @@ -606,17 +638,21 @@ def main():
torch.save({'model': model_to_save.state_dict(),
'optimizer': optimizer.state_dict(),
'master params': list(amp.master_params(optimizer)),
'files': [f_id] + files}, output_save_file)
'files': [f_id] + files,
'epoch': epoch,
'data_loader': None if global_step >= args.max_steps else train_dataloader}, output_save_file)

most_recent_ckpts_paths.append(output_save_file)
if len(most_recent_ckpts_paths) > 3:
ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
os.remove(ckpt_to_be_removed)

if global_step >= args.max_steps:
# Exiting the training due to hitting max steps, or being sent a
# timeout from the cluster scheduler
if global_step >= args.max_steps or timeout_sent:
del train_dataloader
# thread.join()
return args, final_loss, train_time_raw
return args, final_loss, train_time_raw, global_step

del train_dataloader
# thread.join()
Expand All @@ -630,17 +666,17 @@ def main():
if __name__ == "__main__":

now = time.time()
args, final_loss, train_time_raw = main()
args, final_loss, train_time_raw, global_step = main()
gpu_count = args.n_gpu
args.max_steps += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
global_step += args.phase1_end_step if (args.phase2 and args.resume_step > 0) else 0
if args.resume_step == -1:
args.resume_step = 0
if torch.distributed.is_initialized():
gpu_count = torch.distributed.get_world_size()
gpu_count = get_world_size()
if is_main_process():
e2e_time = time.time() - now
training_perf = args.train_batch_size * args.gradient_accumulation_steps * gpu_count\
* (args.max_steps - args.resume_step + skipped_steps) / train_time_raw
* (global_step - args.resume_step + skipped_steps) / train_time_raw
dllogger.log(step=tuple(), data={"e2e_train_time": e2e_time, "training_sequences_per_second": training_perf,
"final_loss": final_loss, "raw_train_time": train_time_raw })
dllogger.flush()
9 changes: 8 additions & 1 deletion PyTorch/LanguageModeling/BERT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def get_rank():
return 0
return dist.get_rank()

def get_world_size():
    """Return the distributed world size, or 1 when not running distributed.

    Returns 1 when ``dist`` (presumably ``torch.distributed`` - confirm against
    this module's imports) reports that it is unavailable or that no process
    group has been initialized; otherwise returns ``dist.get_world_size()``.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()

def is_main_process():
    """Return True when ``get_rank()`` reports rank 0 for this process."""
    return get_rank() == 0

Expand All @@ -34,4 +41,4 @@ def format_step(step):
s += "Training Iteration: {} ".format(step[1])
if len(step) > 2:
s += "Validation Iteration: {} ".format(step[2])
return s
return s