2424import random
2525import sys
2626
27+ import deepspeed
2728import numpy as np
2829import torch
2930from torch .utils .data import (DataLoader , RandomSampler , SequentialSampler ,
@@ -690,7 +691,15 @@ def main():
690691 default = False ,
691692 action = 'store_true' ,
692693 help = "Whether to enable progressive layer dropping or not" )
693-
694+ parser .add_argument (
695+ '--preln' ,
696+ action = 'store_true' ,
697+ default = False ,
698+ help =
699+ "Switching to the variant of Transformer blocks that use pre-LayerNorm."
700+ )
701+
702+ parser = deepspeed .add_config_arguments (parser )
694703 args = parser .parse_args ()
695704
696705 if args .server_ip and args .server_port :
@@ -809,8 +818,10 @@ def main():
809818 if args .progressive_layer_drop :
810819 print ("BertBaseConfigPreLnLayerDrop" )
811820 from nvidia .modelingpreln_layerdrop import BertForSequenceClassification , BertConfig
821+ elif args .preln :
822+ from nvidia .modelingpreln import BertForSequenceClassification , BertConfig , BertLayer
812823 else :
813- from nvidia .modelingpreln import BertForSequenceClassification , BertConfig
824+ from nvidia .modeling import BertForSequenceClassification , BertConfig , BertLayer
814825
815826 bert_config = BertConfig (** bert_base_model_config )
816827 bert_config .vocab_size = len (tokenizer .vocab )
@@ -859,6 +870,19 @@ def main():
859870 elif n_gpu > 1 :
860871 model = torch .nn .DataParallel (model )
861872
873+ # Patch model with deepspeed transformer kernel
874+ if not args .deepspeed_transformer_kernel :
875+ from deepspeed import replace_transformer_layer
876+ model = deepspeed .module_inject .replace_transformer_layer (
877+ orig_layer_impl = BertLayer ,
878+ model = model ,
879+ micro_batch_size = args .train_batch_size ,
880+ bert_config = bert_config ,
881+ seed = args .seed ,
882+ preln = True ,
883+ fp16 = args .fp16 ,
884+ huggingface = False )
885+
862886 # Prepare optimizer
863887 param_optimizer = list (model .named_parameters ())
864888 no_decay = ['bias' , 'LayerNorm.bias' , 'LayerNorm.weight' ]
@@ -871,29 +895,12 @@ def main():
871895 {'params' : [p for n , p in param_optimizer if any (
872896 nd in n for nd in no_decay )], 'weight_decay' : 0.0 }
873897 ]
874- if args .fp16 :
875- try :
876- from apex .optimizers import FP16_Optimizer
877- from apex .optimizers import FusedAdam
878- except ImportError :
879- raise ImportError (
880- "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training." )
881-
882- optimizer = FusedAdam (optimizer_grouped_parameters ,
883- lr = args .learning_rate ,
884- bias_correction = False ,
885- max_grad_norm = 1.0 )
886- if args .loss_scale == 0 :
887- optimizer = FP16_Optimizer (optimizer , dynamic_loss_scale = True )
888- else :
889- optimizer = FP16_Optimizer (
890- optimizer , static_loss_scale = args .loss_scale )
891-
892- else :
893- optimizer = BertAdam (optimizer_grouped_parameters ,
894- lr = args .learning_rate ,
895- warmup = args .warmup_proportion ,
896- t_total = num_train_optimization_steps )
898+
899+ model , optimizer , _ , _ = deepspeed .initialize (
900+ args = args ,
901+ model = model ,
902+ model_parameters = optimizer_grouped_parameters ,
903+ dist_init_required = True )
897904
898905 global_step = 0
899906 nb_tr_steps = 0
0 commit comments