3636""" train fit utility """
3737import logging
3838import math
39+ import glob
3940import os
4041import random
4142import sys
4243import time
44+ import re
4345from itertools import starmap
4446
47+ import signal
48+ import pickle
49+
4550import dllogger
4651import horovod .mxnet as hvd
4752import mxnet as mx
5560from global_metrics import CompositeMeter , MaxMeter , MinMeter , AvgMeter , PercentileMeter
5661
5762
class PartitionSignalHandler():
    """Coordinates a graceful, cluster-wide stop of training partitions.

    On SIGUSR1/SIGTERM the local flag tensor is raised; ``sync`` then
    propagates the flag to every Horovod rank (via allreduce) so all
    workers agree to stop at the same step.
    """

    def __init__(self, sync_freq: int = 10):
        # How often (in calls to ``sync``) to allreduce the stop flag.
        self.freq = sync_freq
        self.step = 0

        # Stop flag: 0 = keep training, >0 = stop requested somewhere.
        self.t = mx.nd.array([0])

        for sig in (signal.SIGUSR1, signal.SIGTERM):
            signal.signal(sig, self._signal_handler)

    def sync(self) -> bool:
        """Periodically exchange the stop flag across ranks; return stop status."""
        if self.step % self.freq == 0:
            # average=False sums the flags, so any rank's signal raises the total.
            total = hvd.allreduce(self.t, average=False)
            if total[0] > 0:
                self.t[0] = 1
        self.step += 1

        return self.should_end()

    def should_end(self) -> bool:
        """Return True once a stop has been requested (locally or remotely)."""
        return bool(self.t[0] > 0)

    def _signal_handler(self, signum, frame):
        # Only raise the local flag here; heavy work is unsafe in a handler.
        self.t[0] = 1
88+
5889def add_fit_args (parser ):
5990 def int_list (x ):
6091 return list (map (int , x .split (',' )))
@@ -79,7 +110,7 @@ def float_list(x):
79110 help = 'the batch size' )
80111 train .add_argument ('--num-epochs' , type = int , default = 90 ,
81112 help = 'number of epochs' )
82- train .add_argument ('--run-epochs' , type = int , default = - 1 ,
113+ train .add_argument ('--run-epochs' , type = int , default = - 1 ,
83114 help = 'number of epochs to run in single run' )
84115 train .add_argument ('--lr' , type = float , default = 0.1 ,
85116 help = 'initial learning rate' )
@@ -134,7 +165,8 @@ def get_epoch_size(args, kv):
134165
135166def get_lr_scheduler (args ):
136167 def multistep_schedule (x ):
137- lr = args .lr * (args .lr_factor ** (len (list (filter (lambda step : step <= x , args .lr_steps )))))
168+ lr = args .lr * \
169+ (args .lr_factor ** (len (list (filter (lambda step : step <= x , args .lr_steps )))))
138170 warmup_coeff = min (1 , x / args .warmup_epochs )
139171 return warmup_coeff * lr
140172
@@ -164,33 +196,49 @@ def cosine_schedule(x):
164196
165197
def load_model(args, model):
    """Load the newest checkpoint from ``args.workspace`` into ``model``.

    Looks for files named ``{model_prefix}_NNNN.params`` (the zero-padded
    epoch number makes lexicographic sort equal to numeric sort) and loads
    the latest one, ignoring the ``_best.params`` snapshot.

    Returns the epoch number of the loaded checkpoint, or 0 when there is
    nothing to resume from (no files, only a ``_best`` file, or an
    unparsable file name).
    """
    files = glob.glob(f"{args.workspace}/{args.model_prefix}_*.params")
    # Drop the "best accuracy" snapshot: it carries no epoch number.
    candidates = sorted(f for f in files if "best.params" not in f)
    if not candidates:
        return 0

    latest = candidates[-1]

    # re.escape guards against regex metacharacters in the prefix;
    # [0-9]+ (not *) so a name without digits is rejected instead of
    # crashing on int('').
    match = re.match(rf".*{re.escape(args.model_prefix)}_([0-9]+)\.params", latest)
    if match is None:
        return 0

    epoch = int(match.group(1))
    model.load_parameters(latest)
    logging.info('Loaded model {}'.format(latest))
    return epoch
214+
215+
def save_checkpoint(net, epoch, top1, best_acc, model_prefix, workspace,
                    save_frequency, kvstore, force_save=False):
    """Persist model parameters into ``workspace``.

    Saves ``{model_prefix}_NNNN.params`` every ``save_frequency`` epochs
    (or unconditionally when ``force_save`` is set, e.g. on a partition
    stop signal), and refreshes ``{model_prefix}_best.params`` whenever
    ``top1`` beats ``best_acc``.

    No-op when checkpointing is disabled (``model_prefix is None`` or
    ``save_frequency == 0``) or on non-root Horovod ranks.
    """
    if model_prefix is None or save_frequency == 0 or ('horovod' in kvstore and hvd.rank() != 0):
        return
    if (save_frequency > 0 and (epoch + 1) % save_frequency == 0) or force_save:
        # Zero-pad the epoch so lexicographic sort matches numeric order.
        fname = '{}_{:04}.params'.format(model_prefix, epoch)
        fname = os.path.join(workspace, fname)
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
            epoch, fname, top1))

    if top1 > best_acc:
        fname = os.path.join(workspace, f'{model_prefix}_best.params')
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
            epoch, fname, top1))
185231
186232
def model_pred(args, model, image):
    """Run a single-image forward pass and log the top-10 class predictions.

    ``image`` is assumed to be a single (un-batched) MXNet NDArray; it is
    reshaped to a batch of one before the forward pass — TODO confirm
    against callers.
    """
    # Local import: keeps the class-name table off the hot import path.
    from imagenet_classes import classes
    output = model(image.reshape(-1, *image.shape)
                   )[0].softmax().as_in_context(mx.cpu())
    top = output.argsort(is_ascend=False)[:10]
    for i, ind in enumerate(top):
        ind = int(ind.asscalar())
        logging.info('{:2d}. {:5.2f}% -> {}'.format(i + 1,
                     output[ind].asscalar() * 100, classes[ind]))
195243
196244def reduce_metrics (args , metrics , kvstore ):
@@ -214,7 +262,8 @@ def model_score(args, net, val_data, metric, kvstore):
214262
215263 val_data .reset ()
216264
217- total_batch_size = val_data .batch_size * val_data ._num_gpus * (hvd .size () if 'horovod' in kvstore else 1 )
265+ total_batch_size = val_data .batch_size * val_data ._num_gpus * \
266+ (hvd .size () if 'horovod' in kvstore else 1 )
218267
219268 durations = []
220269 tic = time .time ()
@@ -225,9 +274,11 @@ def model_score(args, net, val_data, metric, kvstore):
225274 o .wait_to_read ()
226275
227276 data = [b .data [0 ] for b in batches ]
228- label = [b .label [0 ][:len (b .data [0 ]) - b .pad ] for b in batches if len (b .data [0 ]) != b .pad ]
277+ label = [b .label [0 ][:len (b .data [0 ]) - b .pad ]
278+ for b in batches if len (b .data [0 ]) != b .pad ]
229279 outputs = [net (X ) for X , b in zip (data , batches )]
230- outputs = [o [:len (b .data [0 ]) - b .pad ] for o , b in zip (outputs , batches ) if len (b .data [0 ]) != b .pad ]
280+ outputs = [o [:len (b .data [0 ]) - b .pad ]
281+ for o , b in zip (outputs , batches ) if len (b .data [0 ]) != b .pad ]
231282 metric .update (label , outputs )
232283
233284 durations .append (time .time () - tic )
@@ -263,21 +314,24 @@ def model_fit(args, net, train_data, eval_metric, optimizer,
263314 loss_metric = ScalarMetric ()
264315
265316 if 'horovod' in kvstore :
266- trainer = hvd .DistributedTrainer (net .collect_params (), optimizer , optimizer_params )
317+ trainer = hvd .DistributedTrainer (
318+ net .collect_params (), optimizer , optimizer_params )
267319 else :
268320 trainer = gluon .Trainer (net .collect_params (), optimizer , optimizer_params ,
269321 kvstore = kv , update_on_kvstore = False )
270322
271323 if args .amp :
272324 amp .init_trainer (trainer )
273-
325+
326+ partition_handler = PartitionSignalHandler (1 )
274327
275328 sparse_label_loss = (args .label_smoothing == 0 and args .mixup == 0 )
276329 loss = gluon .loss .SoftmaxCrossEntropyLoss (sparse_label = sparse_label_loss )
277330 loss .hybridize (static_shape = True , static_alloc = True )
278331
279332 local_batch_size = train_data .batch_size
280- total_batch_size = local_batch_size * train_data ._num_gpus * (hvd .size () if 'horovod' in kvstore else 1 )
333+ total_batch_size = local_batch_size * train_data ._num_gpus * \
334+ (hvd .size () if 'horovod' in kvstore else 1 )
281335 durations = []
282336
283337 epoch_size = get_epoch_size (args , kv )
@@ -287,16 +341,21 @@ def transform_data(images, labels):
287341 if args .mixup != 0 :
288342 coeffs = mx .nd .array (np .random .beta (args .mixup , args .mixup , size = images .shape [0 ])).as_in_context (
289343 images .context )
290- image_coeffs = coeffs .astype (images .dtype , copy = False ).reshape (* coeffs .shape , 1 , 1 , 1 )
291- ret_images = image_coeffs * images + (1 - image_coeffs ) * images [::- 1 ]
344+ image_coeffs = coeffs .astype (
345+ images .dtype , copy = False ).reshape (* coeffs .shape , 1 , 1 , 1 )
346+ ret_images = image_coeffs * images + \
347+ (1 - image_coeffs ) * images [::- 1 ]
292348
293- ret_labels = label_smoothing (labels , args .num_classes , args .label_smoothing )
349+ ret_labels = label_smoothing (
350+ labels , args .num_classes , args .label_smoothing )
294351 label_coeffs = coeffs .reshape (* coeffs .shape , 1 )
295- ret_labels = label_coeffs * ret_labels + (1 - label_coeffs ) * ret_labels [::- 1 ]
352+ ret_labels = label_coeffs * ret_labels + \
353+ (1 - label_coeffs ) * ret_labels [::- 1 ]
296354 else :
297355 ret_images = images
298356 if not sparse_label_loss :
299- ret_labels = label_smoothing (labels , args .num_classes , args .label_smoothing )
357+ ret_labels = label_smoothing (
358+ labels , args .num_classes , args .label_smoothing )
300359 else :
301360 ret_labels = labels
302361
@@ -315,76 +374,87 @@ def transform_data(images, labels):
315374
316375 logging .info ('Starting epoch {}' .format (epoch ))
317376 outputs = []
318- for i , batches in enumerate (train_data ):
319- # synchronize to previous iteration
320- #for o in outputs:
321- # o.wait_to_read()
322-
323- trainer .set_learning_rate (lr_scheduler (epoch + i / epoch_size ))
324-
325- data = [b .data [0 ] for b in batches ]
326- label = [b .label [0 ].as_in_context (b .data [0 ].context ) for b in batches ]
327- orig_label = label
328-
329- data , label = zip (* starmap (transform_data , zip (data , label )))
330-
331- outputs = []
332- Ls = []
333- with ag .record ():
334- for x , y in zip (data , label ):
335- z = net (x )
336- L = loss (z , y )
337- # store the loss and do backward after we have done forward
338- # on all GPUs for better speed on multiple GPUs.
339- Ls .append (L )
340- outputs .append (z )
341-
342- if args .amp :
343- with amp .scale_loss (Ls , trainer ) as scaled_loss :
344- ag .backward (scaled_loss )
377+ if not partition_handler .should_end ():
378+ for i , batches in enumerate (train_data ):
379+ # synchronize to previous iteration
380+ # for o in outputs:
381+ # o.wait_to_read()
382+
383+ trainer .set_learning_rate (lr_scheduler (epoch + i / epoch_size ))
384+
385+ data = [b .data [0 ] for b in batches ]
386+ label = [b .label [0 ].as_in_context (
387+ b .data [0 ].context ) for b in batches ]
388+ orig_label = label
389+
390+ data , label = zip (* starmap (transform_data , zip (data , label )))
391+
392+ outputs = []
393+ Ls = []
394+ with ag .record ():
395+ for x , y in zip (data , label ):
396+ z = net (x )
397+ L = loss (z , y )
398+ # store the loss and do backward after we have done forward
399+ # on all GPUs for better speed on multiple GPUs.
400+ Ls .append (L )
401+ outputs .append (z )
402+
403+ if args .amp :
404+ with amp .scale_loss (Ls , trainer ) as scaled_loss :
405+ ag .backward (scaled_loss )
406+ else :
407+ ag .backward (Ls )
408+
409+ if 'horovod' in kvstore :
410+ trainer .step (local_batch_size )
345411 else :
346- ag .backward (Ls )
347-
348- if 'horovod' in kvstore :
349- trainer .step (local_batch_size )
350- else :
351- trainer .step (total_batch_size )
412+ trainer .step (total_batch_size )
352413
353- loss_metric .update (..., np .mean ([l .asnumpy () for l in Ls ]).item ())
414+ loss_metric .update (..., np .mean (
415+ [l .asnumpy () for l in Ls ]).item ())
354416
355- if args .disp_batches and not (i + 1 ) % args .disp_batches :
356- dllogger_it_data = {
357- 'train.loss' : loss_metric .get ()[1 ],
358- 'train.ips' : args .disp_batches * total_batch_size / (time .time () - btic ),
359- 'train.lr' : trainer .learning_rate
360- }
361- dllogger .log ((epoch , i ), data = dllogger_it_data )
417+ if args .disp_batches and not (i + 1 ) % args .disp_batches :
418+ dllogger_it_data = {
419+ 'train.loss' : loss_metric .get ()[1 ],
420+ 'train.ips' : args .disp_batches * total_batch_size / (time .time () - btic ),
421+ 'train.lr' : trainer .learning_rate
422+ }
423+ dllogger .log ((epoch , i ), data = dllogger_it_data )
362424
363- loss_metric .reset_local ()
364- btic = time .time ()
425+ loss_metric .reset_local ()
426+ btic = time .time ()
365427
366- durations .append (time .time () - tic )
367- tic = time .time ()
428+ durations .append (time .time () - tic )
429+ tic = time .time ()
368430
369431 durations = durations [min (len (durations ) // 10 , 100 ):]
370432 dllogger_epoch_data = {
371433 'train.loss' : loss_metric .get_global ()[1 ],
372434 'train.ips' : total_batch_size / np .mean (durations )
373435 }
436+
437+ should_break = partition_handler .sync ()
374438 if args .mode == 'train_val' :
375439 logging .info ('Validating epoch {}' .format (epoch ))
376- score , duration_stats , _ = model_score (args , net , eval_data , eval_metric , kvstore )
440+ score , duration_stats , _ = model_score (
441+ args , net , eval_data , eval_metric , kvstore )
377442
378443 dllogger_epoch_data .update (
379- starmap (lambda key , val : ('val.{}' .format (key ), val ), zip (* score ))
444+ starmap (lambda key , val : (
445+ 'val.{}' .format (key ), val ), zip (* score ))
380446 )
381447 dllogger_epoch_data .update (
382- starmap (lambda key , val : ('val.{}' .format (key ), val ), duration_stats .items ())
448+ starmap (lambda key , val : ('val.{}' .format (key ), val ),
449+ duration_stats .items ())
383450 )
384451
385452 score = dict (zip (* score ))
386453 accuracy = score .get ('accuracy' , - 1 )
387- save_checkpoint (net , epoch , accuracy , best_accuracy , model_prefix , args .save_frequency , kvstore )
454+ save_checkpoint (net , epoch , accuracy , best_accuracy ,
455+ model_prefix , args .workspace ,
456+ args .save_frequency , kvstore ,
457+ force_save = should_break )
388458 best_accuracy = max (best_accuracy , accuracy )
389459 global_metrics .update_dict (dllogger_epoch_data )
390460 dllogger .log (step = (epoch ,), data = dllogger_epoch_data )
@@ -446,7 +516,8 @@ def fit(args, model, data_loader):
446516 tic = time .time ()
447517 return
448518
449- if not load_model (args , model ):
519+ start_epoch = load_model (args , model )
520+ if start_epoch == 0 :
450521 # all initializers should be specified in the model definition.
451522 # if not, this will raise an error
452523 model .initialize (mx .init .Initializer ())
@@ -516,7 +587,7 @@ def fit(args, model, data_loader):
516587 args ,
517588 model ,
518589 train ,
519- begin_epoch = args . begin_epoch ,
590+ begin_epoch = start_epoch ,
520591 num_epoch = args .num_epochs ,
521592 run_epoch = args .run_epochs ,
522593 eval_data = val ,
@@ -531,15 +602,19 @@ def fit(args, model, data_loader):
531602 )
532603 elif args .mode == 'val' :
533604 for epoch in range (args .num_epochs ): # loop for benchmarking
534- score , duration_stats , durations = model_score (args , model , val , eval_metrics , args .kv_store )
535- dllogger_data = dict (starmap (lambda key , val : ('val.{}' .format (key ), val ), zip (* score )))
605+ score , duration_stats , durations = model_score (
606+ args , model , val , eval_metrics , args .kv_store )
607+ dllogger_data = dict (starmap (lambda key , val : (
608+ 'val.{}' .format (key ), val ), zip (* score )))
536609 dllogger_data .update (
537- starmap (lambda key , val : ('val.{}' .format (key ), val ), duration_stats .items ())
610+ starmap (lambda key , val : ('val.{}' .format (key ), val ),
611+ duration_stats .items ())
538612 )
539613 global_metrics .update_dict (dllogger_data )
540614 for percentile in [50 , 90 , 95 , 99 , 100 ]:
541615 metric_name = 'val.latency_{}' .format (percentile )
542- dllogger_data [metric_name ] = np .percentile (durations , percentile )
616+ dllogger_data [metric_name ] = np .percentile (
617+ durations , percentile )
543618 global_metrics .update_metric (metric_name , durations )
544619 dllogger .log (step = (epoch ,), data = dllogger_data )
545620 else :
0 commit comments