2727from src .train import train_loop , tencent_trick , load_checkpoint , benchmark_train_loop , benchmark_inference_loop
2828from src .data import get_train_loader , get_val_dataset , get_val_dataloader , get_coco_ground_truth
2929
30+ import dllogger as DLLogger
31+
32+
3033# Apex imports
3134try :
3235 from apex .parallel .LARC import LARC
@@ -72,8 +75,8 @@ def make_parser():
7275 help = 'manually set random seed for torch' )
7376 parser .add_argument ('--checkpoint' , type = str , default = None ,
7477 help = 'path to model checkpoint file' )
75- parser .add_argument ('--save' , action = 'store_true' ,
76- help = 'save model checkpoints' )
78+ parser .add_argument ('--save' , type = str , default = None ,
79+ help = 'save model checkpoints in the specified directory ' )
7780 parser .add_argument ('--mode' , type = str , default = 'training' ,
7881 choices = ['training' , 'evaluation' , 'benchmark-training' , 'benchmark-inference' ])
7982 parser .add_argument ('--evaluation' , nargs = '*' , type = int , default = [21 , 31 , 37 , 42 , 48 , 53 , 59 , 64 ],
@@ -89,7 +92,6 @@ def make_parser():
8992 parser .add_argument ('--weight-decay' , '--wd' , type = float , default = 0.0005 ,
9093 help = 'momentum argument for SGD optimizer' )
9194
92- parser .add_argument ('--profile' , type = int , default = None )
9395 parser .add_argument ('--warmup' , type = int , default = None )
9496 parser .add_argument ('--benchmark-iterations' , type = int , default = 20 , metavar = 'N' ,
9597 help = 'Run N iterations while benchmarking (ignored when training and validation)' )
@@ -104,10 +106,14 @@ def make_parser():
104106 ' When it is not provided, pretrained model from torchvision'
105107 ' will be downloaded.' )
106108 parser .add_argument ('--num-workers' , type = int , default = 4 )
107- parser .add_argument ('--amp' , action = 'store_true' )
109+ parser .add_argument ('--amp' , action = 'store_true' ,
110+ help = 'Whether to enable AMP ops. When false, uses TF32 on A100 and FP32 on V100 GPUS.' )
111+ parser .add_argument ('--json-summary' , type = str , default = None ,
112+ help = 'If provided, the json summary will be written to'
113+ 'the specified file.' )
108114
109115 # Distributed
110- parser .add_argument ('--local_rank' , default = 0 , type = int ,
116+ parser .add_argument ('--local_rank' , default = os . getenv ( 'LOCAL_RANK' , 0 ) , type = int ,
111117 help = 'Used for multi-process training. Can either be manually set ' +
112118 'or automatically set by using \' python -m multiproc\' .' )
113119
@@ -222,29 +228,61 @@ def train(train_loop_func, logger, args):
222228 obj ['model' ] = ssd300 .module .state_dict ()
223229 else :
224230 obj ['model' ] = ssd300 .state_dict ()
225- torch .save (obj , './models/epoch_{}.pt' .format (epoch ))
231+ save_path = os .path .join (args .save , f'epoch_{ epoch } .pt' )
232+ torch .save (obj , save_path )
233+ logger .log ('model path' , save_path )
226234 train_loader .reset ()
227- print ('total training time: {}' .format (total_time ))
228-
235+ DLLogger .log ((), { 'total time' : total_time })
236+ logger .log_summary ()
237+
238+
239+ def log_params (logger , args ):
240+ logger .log_params ({
241+ "dataset path" : args .data ,
242+ "epochs" : args .epochs ,
243+ "batch size" : args .batch_size ,
244+ "eval batch size" : args .eval_batch_size ,
245+ "no cuda" : args .no_cuda ,
246+ "seed" : args .seed ,
247+ "checkpoint path" : args .checkpoint ,
248+ "mode" : args .mode ,
249+ "eval on epochs" : args .evaluation ,
250+ "lr decay epochs" : args .multistep ,
251+ "learning rate" : args .learning_rate ,
252+ "momentum" : args .momentum ,
253+ "weight decay" : args .weight_decay ,
254+ "lr warmup" : args .warmup ,
255+ "backbone" : args .backbone ,
256+ "backbone path" : args .backbone_path ,
257+ "num workers" : args .num_workers ,
258+ "AMP" : args .amp ,
259+ "precision" : 'amp' if args .amp else 'fp32' ,
260+ })
229261
230262if __name__ == "__main__" :
231263 parser = make_parser ()
232264 args = parser .parse_args ()
265+ args .local_rank = int (os .environ .get ('LOCAL_RANK' , args .local_rank ))
233266 if args .local_rank == 0 :
234267 os .makedirs ('./models' , exist_ok = True )
235268
236269 torch .backends .cudnn .benchmark = True
237270
271+ # write json only on the main thread
272+ args .json_summary = args .json_summary if args .local_rank == 0 else None
273+
238274 if args .mode == 'benchmark-training' :
239275 train_loop_func = benchmark_train_loop
240- logger = BenchLogger ('Training benchmark' )
276+ logger = BenchLogger ('Training benchmark' , json_output = args . json_summary )
241277 args .epochs = 1
242278 elif args .mode == 'benchmark-inference' :
243279 train_loop_func = benchmark_inference_loop
244- logger = BenchLogger ('Inference benchmark' )
280+ logger = BenchLogger ('Inference benchmark' , json_output = args . json_summary )
245281 args .epochs = 1
246282 else :
247283 train_loop_func = train_loop
248- logger = Logger ('Training logger' , print_freq = 1 )
284+ logger = Logger ('Training logger' , print_freq = 1 , json_output = args .json_summary )
285+
286+ log_params (logger , args )
249287
250288 train (train_loop_func , logger , args )
0 commit comments