Skip to content

Commit 135fbd9

Browse files
hXl3snv-kkudrynski
authored and committed
[Convnets/MX] Suspend resume support
1 parent c16a623 commit 135fbd9

File tree

1 file changed

+155
-80
lines changed
  • MxNet/Classification/RN50v1.5

1 file changed

+155
-80
lines changed

MxNet/Classification/RN50v1.5/fit.py

Lines changed: 155 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@
3636
""" train fit utility """
3737
import logging
3838
import math
39+
import glob
3940
import os
4041
import random
4142
import sys
4243
import time
44+
import re
4345
from itertools import starmap
4446

47+
import signal
48+
import pickle
49+
4550
import dllogger
4651
import horovod.mxnet as hvd
4752
import mxnet as mx
@@ -55,6 +60,32 @@
5560
from global_metrics import CompositeMeter, MaxMeter, MinMeter, AvgMeter, PercentileMeter
5661

5762

63+
class PartitionSignalHandler():
    """Cooperative shutdown flag shared across horovod workers.

    SIGUSR1/SIGTERM set a local stop flag; ``sync()`` allreduces that flag
    every ``sync_freq`` calls so every worker learns when any worker was
    signalled, letting training checkpoint and stop together.
    """

    def __init__(self, sync_freq: int = 10):
        self.step = 0
        self.freq = sync_freq

        # One-element tensor: 0 = keep running, 1 = stop requested.
        # A tensor (not a bool) so it can be fed to hvd.allreduce.
        self.t = mx.nd.array([0])

        signal.signal(signal.SIGUSR1, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

    def sync(self) -> bool:
        """Periodically combine stop flags across workers; return stop state."""
        # allreduce is a collective call, so every rank must hit it on the
        # same cadence — hence the shared step counter and modulo check.
        if self.step % self.freq == 0:
            combined = hvd.allreduce(self.t, average=False)
            if combined[0] > 0:
                self.t[0] = 1
        self.step += 1

        return self.should_end()

    def should_end(self) -> bool:
        """True once this worker (or, after a sync, any worker) was signalled."""
        return bool(self.t[0] > 0)

    def _signal_handler(self, signum, frame):
        # Only record the request; actual teardown happens in the train loop.
        self.t[0] = 1
5889
def add_fit_args(parser):
5990
def int_list(x):
6091
return list(map(int, x.split(',')))
@@ -79,7 +110,7 @@ def float_list(x):
79110
help='the batch size')
80111
train.add_argument('--num-epochs', type=int, default=90,
81112
help='number of epochs')
82-
train.add_argument('--run-epochs', type=int, default=-1,
113+
train.add_argument('--run-epochs', type=int, default=-1,
83114
help='number of epochs to run in single run')
84115
train.add_argument('--lr', type=float, default=0.1,
85116
help='initial learning rate')
@@ -134,7 +165,8 @@ def get_epoch_size(args, kv):
134165

135166
def get_lr_scheduler(args):
136167
def multistep_schedule(x):
137-
lr = args.lr * (args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
168+
lr = args.lr * \
169+
(args.lr_factor ** (len(list(filter(lambda step: step <= x, args.lr_steps)))))
138170
warmup_coeff = min(1, x / args.warmup_epochs)
139171
return warmup_coeff * lr
140172

@@ -164,33 +196,49 @@ def cosine_schedule(x):
164196

165197

166198
def load_model(args, model):
    """Load the newest epoch checkpoint for ``args.model_prefix`` from ``args.workspace``.

    Returns the epoch number encoded in the checkpoint filename (so training
    can resume from it), or 0 when no usable checkpoint exists.
    """
    candidates = glob.glob(os.path.join(
        args.workspace, f"{args.model_prefix}_*.params"))
    if not candidates:
        return 0

    # '*_best.params' tracks best accuracy, not training progress — resume
    # only from epoch-numbered checkpoints.  Guard against the case where
    # ONLY a best checkpoint exists (original code raised IndexError here).
    epoch_ckpts = [x for x in sorted(candidates) if "best.params" not in x]
    if not epoch_ckpts:
        return 0
    latest = epoch_ckpts[-1]

    # Raw f-string so '\.' is a literal-dot regex escape (the non-raw form
    # is an invalid escape sequence); re.escape guards prefixes that happen
    # to contain regex metacharacters.
    match = re.match(rf".*{re.escape(args.model_prefix)}_([0-9]*)\.params", latest)
    if match is None:
        return 0

    epoch = int(match.group(1))
    model.load_parameters(latest)
    logging.info('Loaded model {}'.format(latest))
    return epoch
215+
216+
def save_checkpoint(net, epoch, top1, best_acc, model_prefix, workspace, save_frequency, kvstore, force_save=False):
    """Persist model parameters into ``workspace``.

    Writes an epoch-numbered checkpoint every ``save_frequency`` epochs (or
    unconditionally when ``force_save`` is set), plus a ``*_best.params``
    snapshot whenever ``top1`` beats ``best_acc``.  Under horovod only rank 0
    writes; checkpointing is disabled entirely when ``model_prefix`` is None
    or ``save_frequency`` is 0.
    """
    if model_prefix is None or save_frequency == 0 or ('horovod' in kvstore and hvd.rank() != 0):
        return

    periodic = save_frequency > 0 and (epoch + 1) % save_frequency == 0
    if periodic or force_save:
        fname = os.path.join(
            workspace, '{}_{:04}.params'.format(model_prefix, epoch))
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
            epoch, fname, top1))

    if top1 > best_acc:
        fname = os.path.join(workspace, f'{model_prefix}_best.params')
        net.save_parameters(fname)
        logging.info('[Epoch {}] Saving checkpoint to {} with Accuracy: {:.4f}'.format(
            epoch, fname, top1))
185231

186232

187233
def model_pred(args, model, image):
    """Run the model on a single image and log its top-10 class predictions."""
    from imagenet_classes import classes
    # Add a leading batch axis, take softmax, and pull probabilities to CPU.
    batched = image.reshape(-1, *image.shape)
    probs = model(batched)[0].softmax().as_in_context(mx.cpu())
    ranked = probs.argsort(is_ascend=False)[:10]
    for rank, idx in enumerate(ranked):
        idx = int(idx.asscalar())
        logging.info('{:2d}. {:5.2f}% -> {}'.format(
            rank + 1, probs[idx].asscalar() * 100, classes[idx]))
194242

195243

196244
def reduce_metrics(args, metrics, kvstore):
@@ -214,7 +262,8 @@ def model_score(args, net, val_data, metric, kvstore):
214262

215263
val_data.reset()
216264

217-
total_batch_size = val_data.batch_size * val_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)
265+
total_batch_size = val_data.batch_size * val_data._num_gpus * \
266+
(hvd.size() if 'horovod' in kvstore else 1)
218267

219268
durations = []
220269
tic = time.time()
@@ -225,9 +274,11 @@ def model_score(args, net, val_data, metric, kvstore):
225274
o.wait_to_read()
226275

227276
data = [b.data[0] for b in batches]
228-
label = [b.label[0][:len(b.data[0]) - b.pad] for b in batches if len(b.data[0]) != b.pad]
277+
label = [b.label[0][:len(b.data[0]) - b.pad]
278+
for b in batches if len(b.data[0]) != b.pad]
229279
outputs = [net(X) for X, b in zip(data, batches)]
230-
outputs = [o[:len(b.data[0]) - b.pad] for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
280+
outputs = [o[:len(b.data[0]) - b.pad]
281+
for o, b in zip(outputs, batches) if len(b.data[0]) != b.pad]
231282
metric.update(label, outputs)
232283

233284
durations.append(time.time() - tic)
@@ -263,21 +314,24 @@ def model_fit(args, net, train_data, eval_metric, optimizer,
263314
loss_metric = ScalarMetric()
264315

265316
if 'horovod' in kvstore:
266-
trainer = hvd.DistributedTrainer(net.collect_params(), optimizer, optimizer_params)
317+
trainer = hvd.DistributedTrainer(
318+
net.collect_params(), optimizer, optimizer_params)
267319
else:
268320
trainer = gluon.Trainer(net.collect_params(), optimizer, optimizer_params,
269321
kvstore=kv, update_on_kvstore=False)
270322

271323
if args.amp:
272324
amp.init_trainer(trainer)
273-
325+
326+
partition_handler = PartitionSignalHandler(1)
274327

275328
sparse_label_loss = (args.label_smoothing == 0 and args.mixup == 0)
276329
loss = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
277330
loss.hybridize(static_shape=True, static_alloc=True)
278331

279332
local_batch_size = train_data.batch_size
280-
total_batch_size = local_batch_size * train_data._num_gpus * (hvd.size() if 'horovod' in kvstore else 1)
333+
total_batch_size = local_batch_size * train_data._num_gpus * \
334+
(hvd.size() if 'horovod' in kvstore else 1)
281335
durations = []
282336

283337
epoch_size = get_epoch_size(args, kv)
@@ -287,16 +341,21 @@ def transform_data(images, labels):
287341
if args.mixup != 0:
288342
coeffs = mx.nd.array(np.random.beta(args.mixup, args.mixup, size=images.shape[0])).as_in_context(
289343
images.context)
290-
image_coeffs = coeffs.astype(images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
291-
ret_images = image_coeffs * images + (1 - image_coeffs) * images[::-1]
344+
image_coeffs = coeffs.astype(
345+
images.dtype, copy=False).reshape(*coeffs.shape, 1, 1, 1)
346+
ret_images = image_coeffs * images + \
347+
(1 - image_coeffs) * images[::-1]
292348

293-
ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
349+
ret_labels = label_smoothing(
350+
labels, args.num_classes, args.label_smoothing)
294351
label_coeffs = coeffs.reshape(*coeffs.shape, 1)
295-
ret_labels = label_coeffs * ret_labels + (1 - label_coeffs) * ret_labels[::-1]
352+
ret_labels = label_coeffs * ret_labels + \
353+
(1 - label_coeffs) * ret_labels[::-1]
296354
else:
297355
ret_images = images
298356
if not sparse_label_loss:
299-
ret_labels = label_smoothing(labels, args.num_classes, args.label_smoothing)
357+
ret_labels = label_smoothing(
358+
labels, args.num_classes, args.label_smoothing)
300359
else:
301360
ret_labels = labels
302361

@@ -315,76 +374,87 @@ def transform_data(images, labels):
315374

316375
logging.info('Starting epoch {}'.format(epoch))
317376
outputs = []
318-
for i, batches in enumerate(train_data):
319-
# synchronize to previous iteration
320-
#for o in outputs:
321-
# o.wait_to_read()
322-
323-
trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))
324-
325-
data = [b.data[0] for b in batches]
326-
label = [b.label[0].as_in_context(b.data[0].context) for b in batches]
327-
orig_label = label
328-
329-
data, label = zip(*starmap(transform_data, zip(data, label)))
330-
331-
outputs = []
332-
Ls = []
333-
with ag.record():
334-
for x, y in zip(data, label):
335-
z = net(x)
336-
L = loss(z, y)
337-
# store the loss and do backward after we have done forward
338-
# on all GPUs for better speed on multiple GPUs.
339-
Ls.append(L)
340-
outputs.append(z)
341-
342-
if args.amp:
343-
with amp.scale_loss(Ls, trainer) as scaled_loss:
344-
ag.backward(scaled_loss)
377+
if not partition_handler.should_end():
378+
for i, batches in enumerate(train_data):
379+
# synchronize to previous iteration
380+
# for o in outputs:
381+
# o.wait_to_read()
382+
383+
trainer.set_learning_rate(lr_scheduler(epoch + i / epoch_size))
384+
385+
data = [b.data[0] for b in batches]
386+
label = [b.label[0].as_in_context(
387+
b.data[0].context) for b in batches]
388+
orig_label = label
389+
390+
data, label = zip(*starmap(transform_data, zip(data, label)))
391+
392+
outputs = []
393+
Ls = []
394+
with ag.record():
395+
for x, y in zip(data, label):
396+
z = net(x)
397+
L = loss(z, y)
398+
# store the loss and do backward after we have done forward
399+
# on all GPUs for better speed on multiple GPUs.
400+
Ls.append(L)
401+
outputs.append(z)
402+
403+
if args.amp:
404+
with amp.scale_loss(Ls, trainer) as scaled_loss:
405+
ag.backward(scaled_loss)
406+
else:
407+
ag.backward(Ls)
408+
409+
if 'horovod' in kvstore:
410+
trainer.step(local_batch_size)
345411
else:
346-
ag.backward(Ls)
347-
348-
if 'horovod' in kvstore:
349-
trainer.step(local_batch_size)
350-
else:
351-
trainer.step(total_batch_size)
412+
trainer.step(total_batch_size)
352413

353-
loss_metric.update(..., np.mean([l.asnumpy() for l in Ls]).item())
414+
loss_metric.update(..., np.mean(
415+
[l.asnumpy() for l in Ls]).item())
354416

355-
if args.disp_batches and not (i + 1) % args.disp_batches:
356-
dllogger_it_data = {
357-
'train.loss': loss_metric.get()[1],
358-
'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
359-
'train.lr': trainer.learning_rate
360-
}
361-
dllogger.log((epoch, i), data=dllogger_it_data)
417+
if args.disp_batches and not (i + 1) % args.disp_batches:
418+
dllogger_it_data = {
419+
'train.loss': loss_metric.get()[1],
420+
'train.ips': args.disp_batches * total_batch_size / (time.time() - btic),
421+
'train.lr': trainer.learning_rate
422+
}
423+
dllogger.log((epoch, i), data=dllogger_it_data)
362424

363-
loss_metric.reset_local()
364-
btic = time.time()
425+
loss_metric.reset_local()
426+
btic = time.time()
365427

366-
durations.append(time.time() - tic)
367-
tic = time.time()
428+
durations.append(time.time() - tic)
429+
tic = time.time()
368430

369431
durations = durations[min(len(durations) // 10, 100):]
370432
dllogger_epoch_data = {
371433
'train.loss': loss_metric.get_global()[1],
372434
'train.ips': total_batch_size / np.mean(durations)
373435
}
436+
437+
should_break = partition_handler.sync()
374438
if args.mode == 'train_val':
375439
logging.info('Validating epoch {}'.format(epoch))
376-
score, duration_stats, _ = model_score(args, net, eval_data, eval_metric, kvstore)
440+
score, duration_stats, _ = model_score(
441+
args, net, eval_data, eval_metric, kvstore)
377442

378443
dllogger_epoch_data.update(
379-
starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score))
444+
starmap(lambda key, val: (
445+
'val.{}'.format(key), val), zip(*score))
380446
)
381447
dllogger_epoch_data.update(
382-
starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
448+
starmap(lambda key, val: ('val.{}'.format(key), val),
449+
duration_stats.items())
383450
)
384451

385452
score = dict(zip(*score))
386453
accuracy = score.get('accuracy', -1)
387-
save_checkpoint(net, epoch, accuracy, best_accuracy, model_prefix, args.save_frequency, kvstore)
454+
save_checkpoint(net, epoch, accuracy, best_accuracy,
455+
model_prefix, args.workspace,
456+
args.save_frequency, kvstore,
457+
force_save=should_break)
388458
best_accuracy = max(best_accuracy, accuracy)
389459
global_metrics.update_dict(dllogger_epoch_data)
390460
dllogger.log(step=(epoch,), data=dllogger_epoch_data)
@@ -446,7 +516,8 @@ def fit(args, model, data_loader):
446516
tic = time.time()
447517
return
448518

449-
if not load_model(args, model):
519+
start_epoch = load_model(args, model)
520+
if start_epoch == 0:
450521
# all initializers should be specified in the model definition.
451522
# if not, this will raise an error
452523
model.initialize(mx.init.Initializer())
@@ -516,7 +587,7 @@ def fit(args, model, data_loader):
516587
args,
517588
model,
518589
train,
519-
begin_epoch=args.begin_epoch,
590+
begin_epoch=start_epoch,
520591
num_epoch=args.num_epochs,
521592
run_epoch=args.run_epochs,
522593
eval_data=val,
@@ -531,15 +602,19 @@ def fit(args, model, data_loader):
531602
)
532603
elif args.mode == 'val':
533604
for epoch in range(args.num_epochs): # loop for benchmarking
534-
score, duration_stats, durations = model_score(args, model, val, eval_metrics, args.kv_store)
535-
dllogger_data = dict(starmap(lambda key, val: ('val.{}'.format(key), val), zip(*score)))
605+
score, duration_stats, durations = model_score(
606+
args, model, val, eval_metrics, args.kv_store)
607+
dllogger_data = dict(starmap(lambda key, val: (
608+
'val.{}'.format(key), val), zip(*score)))
536609
dllogger_data.update(
537-
starmap(lambda key, val: ('val.{}'.format(key), val), duration_stats.items())
610+
starmap(lambda key, val: ('val.{}'.format(key), val),
611+
duration_stats.items())
538612
)
539613
global_metrics.update_dict(dllogger_data)
540614
for percentile in [50, 90, 95, 99, 100]:
541615
metric_name = 'val.latency_{}'.format(percentile)
542-
dllogger_data[metric_name] = np.percentile(durations, percentile)
616+
dllogger_data[metric_name] = np.percentile(
617+
durations, percentile)
543618
global_metrics.update_metric(metric_name, durations)
544619
dllogger.log(step=(epoch,), data=dllogger_data)
545620
else:

0 commit comments

Comments
 (0)