forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·580 lines (465 loc) · 19.7 KB
/
utils.py
File metadata and controls
executable file
·580 lines (465 loc) · 19.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
'''
Copyright 2019 The Microsoft DeepSpeed Team
Copyright NVIDIA/Megatron
Helper functions and classes from multiple sources.
'''
import os
from math import ceil
from math import floor
from bisect import bisect_left, bisect_right
import torch
import torch.distributed as dist
from torch._six import inf
import torch.distributed as dist
from deepspeed.utils import logger
from numpy import prod
def ensure_directory_exists(filename):
    """Create the directory path to ``filename`` if it does not already exist.

    Args:
        filename (str): A file path. A bare file name (no directory component)
            refers to the current directory, so nothing is created.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError; a bare file name needs no directory.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    Args:
        seed (int): value passed to each library's seeding function.
    """
    import random
    import numpy

    # Seed in a fixed order: stdlib, NumPy, then PyTorch.
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(seed)
def move_to_device(item, device):
    """
    Move tensor onto device. Works on individual tensors, and tensors contained/nested in lists, tuples, and dicts.

    Parameters:
        item: tensor to move or (possibly nested) container of tensors to move.
        device: target device

    Returns:
        An object with the same structure as ``item`` in which every tensor has
        been replaced by ``tensor.to(device)``. Lists, tuples (namedtuples keep
        their type) and dicts are rebuilt; any other object is returned as-is.
    """
    if torch.is_tensor(item):
        return item.to(device)
    if isinstance(item, list):
        return [move_to_device(v, device) for v in item]
    if isinstance(item, tuple):
        moved = [move_to_device(v, device) for v in item]
        # namedtuples take their fields positionally; rebuilding them with
        # tuple() would silently drop the field names.
        if hasattr(item, '_fields'):
            return type(item)(*moved)
        return tuple(moved)
    if isinstance(item, dict):
        return {k: move_to_device(v, device) for k, v in item.items()}
    return item
class CheckOverflow(object):
    '''Checks for overflow in gradient across parallel process

    Args:
        param_groups: optional iterable of parameter groups (each an iterable of
            params); flattened into ``self.params`` for use by ``check()``.
        mpu: optional model-parallel unit; when set, overflow flags are
            MAX-reduced over ``mpu.get_model_parallel_group()``.
        zero_reduce_scatter (bool): when True, ``has_overflow`` MAX-reduces the
            flag over the whole world (``group.WORLD``) instead.
    '''
    def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False):
        self.mpu = mpu
        # None (rather than []) records that no param_groups were given, in
        # which case check() must be called with explicit param_groups.
        self.params = [] if param_groups else None
        self.zero_reduce_scatter = zero_reduce_scatter
        if param_groups:
            for group in param_groups:
                for param in group:
                    self.params.append(param)

    def check_using_norm(self, norm_group, reduce_overflow=True):
        """Return True if any entry of ``norm_group`` is the -1 overflow sentinel.

        The local flag is MAX-reduced over the model-parallel group when ``mpu``
        is set; otherwise, if ``reduce_overflow`` is True, over the default
        process group (followed by a barrier).
        """
        #TODO: I don't think reduce_overflow is needed if mpu is None
        overflow = -1 in norm_group

        if self.mpu is not None:
            overflow_gpu = torch.cuda.ByteTensor([overflow])
            torch.distributed.all_reduce(overflow_gpu,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self.mpu.get_model_parallel_group())
            overflow = overflow_gpu[0].item()
        elif reduce_overflow:
            cuda_overflow = torch.cuda.FloatTensor([overflow])
            dist.all_reduce(cuda_overflow, op=torch.distributed.ReduceOp.MAX)
            dist.barrier()
            overflow = cuda_overflow[0].item()

        return bool(overflow)

    def check(self, param_groups=None):
        """Check ``param_groups`` (or, if None, the params given at construction) for overflow.

        Raises:
            AssertionError: if neither ``param_groups`` nor constructor params exist.
        """
        if param_groups is None:
            # Previously this assertion sat in the branch where param_groups is
            # never None, so it could not fire and has_overflow(None) crashed later.
            assert self.params is not None, \
                "self.params and param_groups both cannot be none"
            params = self.params
        else:
            params = [param for group in param_groups for param in group]
        return self.has_overflow(params)

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        """Return True if any param has a gradient containing inf or NaN (local only)."""
        for i, p in enumerate(params):
            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
                return True
        return False

    def has_overflow(self, params):
        """Local overflow check on ``params``, MAX-reduced across ranks.

        Reduces over the whole world when ``zero_reduce_scatter`` is set, else
        over the model-parallel group when ``mpu`` is set, else stays local.
        """
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = torch.cuda.ByteTensor([overflow])
        if self.zero_reduce_scatter:
            torch.distributed.all_reduce(overflow_gpu,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=torch.distributed.group.WORLD)
        elif self.mpu is not None:
            torch.distributed.all_reduce(overflow_gpu,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self.mpu.get_model_parallel_group())
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x, i):
        """Return True if ``x`` contains inf/NaN (detected via its float sum).

        ``i`` is unused here; callers pass the tensor's index for debugging.
        """
        try:
            # if x is half, the .float() incurs an additional deep copy, but it's necessary if
            # Pytorch's .sum() creates a one-element tensor of the same type as x
            # (which is true for some recent version of pytorch).
            cpu_sum = float(x.float().sum())
            # More efficient version that can be used if .sum() returns a Python scalar
            # cpu_sum = float(x.sum())
        except RuntimeError as instance:
            # We want to check if inst is actually an overflow exception.
            # RuntimeError could come from a different error.
            # If so, we still want the exception to propagate.
            if "value cannot be converted" not in instance.args[0]:
                raise
            return True
        else:
            # NaN is the only float that is not equal to itself.
            if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
                return True
            return False
def _handle_overflow(cpu_sum, x, i):
    """On rank 0, log where an overflow was detected in tensor ``x``.

    Args:
        cpu_sum: the value that flagged the overflow; only included in the log.
        x (torch.Tensor): the tensor that overflowed.
        i: identifier of ``x`` for the log message (presumably its index in the
            parameter list, as passed to ``_has_inf_or_nan`` — confirm at call site).
    """
    rank = torch.distributed.get_rank()
    if rank == 0:
        flat = x.data.contiguous().view(-1)
        # Vectorized search for the first non-finite element instead of a
        # Python loop that converts (and, on GPU, syncs) every element.
        nonfinite = (~torch.isfinite(flat)).nonzero()
        t_i = nonfinite[0, 0].item() if nonfinite.numel() > 0 else -1
        logger.info(
            f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
        )
def get_grad_norm(parameters, norm_type=2, mpu=None):
    """Compute the norm of the gradients of an iterable of parameters.

    Adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` (taken from
    Nvidia Megatron) with added handling of model parallel parameters. Unlike
    ``clip_grad_norm_``, nothing is clipped or modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose ``.grad`` is measured; params without a grad
            are skipped.
        norm_type (float or int or str): type of the used p-norm. Can be
            ``'inf'`` for infinity norm.
        mpu (optional): model-parallel unit. When given, the norm is reduced
            over ``mpu.get_model_parallel_group()``, and parameters not marked
            model-parallel only contribute on model-parallel rank 0 so they are
            counted once in the summed norm.

    Returns:
        Total norm of the gradients (viewed as a single vector) as a Python
        float; 0.0 when there are no gradients, -1 if the norm is inf or NaN.
    """
    import math

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    # A CUDA tensor is only needed for the model-parallel all_reduce; the
    # float32 round-trip yields the same value on the CPU.
    device = 'cuda' if mpu is not None else 'cpu'

    if norm_type == math.inf:
        total_norm = max((float(p.grad.data.abs().max()) for p in parameters), default=0.)
        total_norm_t = torch.tensor([total_norm], dtype=torch.float32, device=device)
        # Take max across all GPUs.
        if mpu is not None:
            torch.distributed.all_reduce(total_norm_t,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=mpu.get_model_parallel_group())
        total_norm = total_norm_t[0].item()
    else:
        total_norm = 0.
        for p in parameters:
            if mpu is not None and not (mpu.get_model_parallel_rank() == 0
                                        or is_model_parallel_parameter(p)):
                continue
            param_norm = p.grad.data.float().norm(norm_type)
            total_norm += param_norm.item()**norm_type
        # Sum across all model parallel GPUs.
        total_norm_t = torch.tensor([float(total_norm)], dtype=torch.float32, device=device)
        if mpu is not None:
            torch.distributed.all_reduce(total_norm_t,
                                         op=torch.distributed.ReduceOp.SUM,
                                         group=mpu.get_model_parallel_group())
        total_norm = total_norm_t[0].item()**(1. / norm_type)

    # -1 is the overflow sentinel checked by CheckOverflow.check_using_norm.
    if not math.isfinite(total_norm):
        total_norm = -1
    return total_norm
def get_weight_norm(parameters, norm_type=2, mpu=None):
    """Compute the norm of the values of an iterable of parameters (the weights).

    Adapted from ``torch.nn.utils.clip_grad.clip_grad_norm_`` (taken from
    Nvidia Megatron) with added handling of model parallel parameters. Nothing
    is clipped or modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor whose values are measured.
        norm_type (float or int or str): type of the used p-norm. Can be
            ``'inf'`` for infinity norm.
        mpu (optional): model-parallel unit. When given, the norm is reduced
            over ``mpu.get_model_parallel_group()``, and parameters not marked
            model-parallel only contribute on model-parallel rank 0 so they are
            counted once in the summed norm.

    Returns:
        Total norm of the parameters (viewed as a single vector) as a Python
        float; 0.0 when there are no parameters, -1 if the norm is inf or NaN.
    """
    import math

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)
    # A CUDA tensor is only needed for the model-parallel all_reduce; the
    # float32 round-trip yields the same value on the CPU.
    device = 'cuda' if mpu is not None else 'cpu'

    if norm_type == math.inf:
        total_norm = max((float(p.data.abs().max()) for p in parameters), default=0.)
        total_norm_t = torch.tensor([total_norm], dtype=torch.float32, device=device)
        # Take max across all GPUs.
        if mpu is not None:
            torch.distributed.all_reduce(total_norm_t,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=mpu.get_model_parallel_group())
        total_norm = total_norm_t[0].item()
    else:
        total_norm = 0.
        for p in parameters:
            if mpu is not None and not (mpu.get_model_parallel_rank() == 0
                                        or is_model_parallel_parameter(p)):
                continue
            try:
                param_norm = float(torch.norm(p, norm_type, dtype=torch.float32))
            except TypeError:
                # Fallback when torch.norm rejects the ``dtype`` keyword
                # (presumably older PyTorch releases).
                param_norm = float(torch.norm(p.float(), norm_type))
            total_norm += param_norm**norm_type
        # Sum across all model parallel GPUs.
        total_norm_t = torch.tensor([float(total_norm)], dtype=torch.float32, device=device)
        if mpu is not None:
            torch.distributed.all_reduce(total_norm_t,
                                         op=torch.distributed.ReduceOp.SUM,
                                         group=mpu.get_model_parallel_group())
        total_norm = total_norm_t[0].item()**(1. / norm_type)

    # -1 is the overflow sentinel used by the gradient-norm helpers.
    if not math.isfinite(total_norm):
        total_norm = -1
    return total_norm
def is_model_parallel_parameter(p):
    """Return ``p.model_parallel`` when that attribute exists, otherwise False."""
    return getattr(p, 'model_parallel', False)
def prefix_sum_inc(weights):
    """ Compute an inclusive prefix sum.

    Element ``k`` of the returned list is ``weights[0] + ... + weights[k]``.

    Example:
        >>> prefix_sum_inc([3,4,5])
        [3, 7, 12]
    """
    running = []
    for w in weights:
        # Accumulate onto the previous cumulative total (if any).
        if running:
            w += running[-1]
        running.append(w)
    return running
def partition_uniform(num_items, num_parts):
    """Split ``num_items`` items into ``num_parts`` contiguous parts.

    Returns a list of ``num_parts + 1`` boundaries: part ``p`` covers items
    ``[parts[p], parts[p + 1])``. With at most one item per part, the first
    ``num_items`` parts get one item each and the rest are empty. Otherwise,
    every part gets ``floor(num_items / num_parts)`` items and the last part
    also absorbs the remainder.
    """
    # Trivial case: no more items than parts.
    if num_items <= num_parts:
        return [min(p, num_items) for p in range(num_parts + 1)]

    chunk = floor(num_items / num_parts)
    bounds = [min(chunk * p, num_items) for p in range(num_parts)]
    bounds.append(num_items)
    return bounds
def _lprobe(weights, num_parts, bottleneck):
    """Greedily probe whether ``weights`` can be split into ``num_parts`` parts
    whose weights stay within ``bottleneck``.

    ``weights`` is expected to be an inclusive prefix sum of the per-item
    weights (``partition_balanced`` passes ``prefix_sum_inc(weights)``).

    Returns:
        (parts, success): ``parts`` is a list of ``num_parts + 1`` boundaries
        (part ``p`` covers items ``[parts[p], parts[p + 1])``), and ``success``
        tells whether this greedy split satisfied the bottleneck according to
        the checks below.
    """
    num_items = len(weights)
    total_weight = weights[-1]

    # initialize partitioning
    # Every boundary after the first starts at num_items, so parts left unassigned
    # by an early return stay empty.
    parts = [0] * (num_parts + 1)
    for p in range(1, num_parts + 1):
        parts[p] = num_items

    bsum = bottleneck  # running sum of target weight for pth partition
    # Coarse stride used to skip ahead before bisecting. Callers ensure
    # num_items > num_parts, so chunksize >= 1.
    chunksize = num_items // num_parts
    step = chunksize
    for p in range(1, num_parts):
        # Jump to the next bucket
        # (advance by whole chunks until the prefix sum at `step` reaches bsum)
        while (step < num_items) and (weights[step] < bsum):
            step += chunksize

        # Find the end index of partition p
        # bisect_left returns the first index in [lo, hi) whose prefix sum is >= bsum.
        parts[p] = bisect_left(weights,
                               bsum,
                               lo=step - chunksize,
                               hi=min(step,
                                      num_items))
        # Nothing more to partition, return early
        if parts[p] == num_items:
            # See if the current partition is overweight.
            # NOTE(review): with inclusive prefix sums, items [a, num_items) weigh
            # weights[-1] - weights[a - 1]; subtracting weights[parts[p - 1]]
            # leaves out that partition's first item — confirm this is intended.
            part_size = weights[-1] - weights[parts[p - 1]]
            return parts, part_size < bottleneck

        # Next partition target
        # (prefix sum at the end of partition p, plus one more bottleneck's worth)
        bsum = weights[parts[p] - 1] + bottleneck

    # The final partition takes everything that remains; it fits if the last
    # target reaches the total weight.
    return parts, bsum >= total_weight
def _rb_partition_balanced(weights, num_parts, eps):
    """Binary-search the smallest bottleneck for which ``_lprobe`` succeeds.

    Args:
        weights: inclusive prefix sums of the item weights.
        num_parts (int): number of partitions.
        eps (float): search tolerance.

    Returns:
        The smallest feasible bottleneck found (the upper end of the final
        search interval).
    """
    lo = weights[-1] / num_parts  # best case: perfectly even parts
    hi = weights[-1]  # worst case: everything in one part

    while hi > lo + eps:
        candidate = lo + (hi - lo) / 2
        _, feasible = _lprobe(weights, num_parts, candidate)
        if feasible:
            hi = candidate
        else:
            lo = candidate + eps
    return hi
def partition_balanced(weights, num_parts, eps=1e-3):
    """Split items with per-item ``weights`` into ``num_parts`` contiguous parts,
    minimizing the weight of the heaviest part (to within ``eps``).

    Returns:
        A list of ``num_parts + 1`` boundaries; part ``p`` covers items
        ``[parts[p], parts[p + 1])``.
    """
    count = len(weights)
    # Too few items to balance: fall back to the uniform split.
    if count <= num_parts:
        return partition_uniform(count, num_parts)

    prefix = prefix_sum_inc(weights)
    # Smallest achievable weight of the heaviest partition.
    bottleneck = _rb_partition_balanced(prefix, num_parts, eps=eps)
    # Recompute the partitioning that achieves that bottleneck.
    parts, success = _lprobe(prefix, num_parts, bottleneck)
    assert success
    return parts
class PartitionedTensor:
    """A tensor whose flattened data is split across the ranks of ``group``.

    The flat tensor is split with ``partition_uniform``; each rank keeps only its
    own contiguous slice in ``local_data``. ``full()`` all-gathers the slices back
    into a tensor of the original shape, and ``to_meta()`` / ``from_meta()``
    serialize the shape and partition layout.
    """
    def __init__(self, tensor, group, partition_meta=None):
        # NOTE(review): partition_meta is accepted but unused.
        super().__init__()

        self.group = group
        self.num_parts = dist.get_world_size(group=self.group)
        self.rank = dist.get_rank(group=self.group)

        self.orig_size = list(tensor.size())
        self.orig_device = tensor.device
        self.local_data, self.partition = self._partition_tensor(tensor)

    @classmethod
    def from_meta(cls, meta, local_part, group, device='cuda'):
        """Rebuild a ``PartitionedTensor`` from ``to_meta()`` output and this rank's slice."""
        assert meta.dtype == torch.long
        # Construct from a dummy tensor, then overwrite the fields from ``meta``.
        dummy = torch.ones(dist.get_world_size(group=group))
        part_obj = cls(tensor=dummy, group=group)

        meta = meta.tolist()

        # [N, list0, ..., listN-1]
        part_obj.orig_size = meta[1:(1 + meta[0])]
        meta = meta[1 + meta[0]:]

        part_obj.orig_device = device
        part_obj.local_data = local_part.detach()

        part_obj.group = group

        # Partition is encoded like the rowptr of a CSR matrix:
        # [num_parts, rank, 0, part_1, ..., part_num_parts]
        # TODO: support shuffle between different partition granularities
        assert part_obj.num_parts == meta[0]
        assert part_obj.rank == meta[1]
        part_obj.partition = meta[2:]  # length num_parts+1

        return part_obj

    def _partition_tensor(self, tensor):
        """Return (this rank's flat slice of ``tensor``, partition boundaries)."""
        partition = partition_uniform(num_items=tensor.numel(), num_parts=self.num_parts)
        start = partition[self.rank]
        length = partition[self.rank + 1] - start
        tensor_part = tensor.detach().contiguous().view(-1).narrow(
            0,
            start=start,
            length=length).clone()

        return tensor_part, partition

    def full(self, device=None):
        """All-gather every rank's slice and return the tensor in its original shape."""
        if device is None:
            device = self.orig_device

        # Allocate the full tensor as a flat buffer. numpy.prod of an empty
        # shape (a 0-dim tensor) is the float 1.0, which torch.zeros rejects,
        # so cast to int.
        full_numel = int(prod(self.full_size()))
        flat_tensor = torch.zeros([full_numel],
                                  dtype=self.local_data.dtype,
                                  device=device)

        # Prepare all-gather buffer: one view into flat_tensor per partition.
        partition_tensors = []
        for part_id in range(self.num_parts):
            part_size = self.partition[part_id + 1] - self.partition[part_id]
            buf = flat_tensor.narrow(0, start=self.partition[part_id], length=part_size)
            if part_id == self.rank:
                buf.copy_(self.local_data)
            partition_tensors.append(buf)

        # Collect the full tensor
        dist.all_gather(partition_tensors,
                        partition_tensors[self.rank],
                        group=self.group)

        for i in range(len(partition_tensors)):
            partition_tensors[i].data = torch.zeros(1)
            partition_tensors[i] = None

        return flat_tensor.view(self.full_size()).clone().detach()

    def to_meta(self):
        """Returns a torch.LongTensor that encodes partitioning information.

        Can be used along with ``data()`` to serialize a ``PartitionedTensor`` for
        communication.

        Layout: ``[ndim, *orig_size, num_parts, rank, *partition]``.

        Returns:
            torch.LongTensor: a tensor encoding the meta-information for the partitioning
        """
        meta = []
        meta.append(len(self.orig_size))
        meta += list(self.orig_size)
        meta.append(self.num_parts)
        meta.append(self.rank)
        meta += self.partition
        return torch.tensor(meta, dtype=torch.long).to(self.orig_device)

    def data(self):
        """Return this rank's flat slice."""
        return self.local_data

    def local_size(self):
        """Return the size of this rank's flat slice."""
        return self.local_data.size()

    def full_size(self):
        """Return the original (unpartitioned) shape as a list."""
        return self.orig_size
mem_alloced = 0  # bytes allocated as of the previous printing memory_status() call
mem_cached = 0  # bytes reserved by the caching allocator as of that call


def memory_status(msg, print_rank=-1, reset_max=False):
    """Print current, delta and peak CUDA memory usage for the calling rank.

    Args:
        msg: label printed alongside the statistics.
        print_rank (int): only this rank prints; -1 means every rank prints.
        reset_max (bool): reset the peak-memory counters before reading them.

    Deltas are relative to the module globals ``mem_alloced`` / ``mem_cached``,
    which every printing call updates.
    """
    global mem_alloced, mem_cached

    rank = dist.get_rank()
    if print_rank != -1 and rank != print_rank:
        return

    torch.cuda.synchronize()

    # The ``*_cached`` APIs are deprecated aliases of ``*_reserved`` in newer
    # PyTorch; prefer the new names when present to avoid FutureWarnings.
    memory_reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    max_memory_reserved = (getattr(torch.cuda, 'max_memory_reserved', None)
                           or torch.cuda.max_memory_cached)

    if reset_max:
        if hasattr(torch.cuda, 'reset_peak_memory_stats'):
            # Resets the peaks of both allocated and reserved memory.
            torch.cuda.reset_peak_memory_stats()
        else:
            torch.cuda.reset_max_memory_cached()
            torch.cuda.reset_max_memory_allocated()

    new_alloced = torch.cuda.memory_allocated()
    new_cached = memory_reserved()

    delta_alloced = new_alloced - mem_alloced
    delta_cached = new_cached - mem_cached

    mem_cached = new_cached
    mem_alloced = new_alloced

    max_alloced = torch.cuda.max_memory_allocated()
    max_cached = max_memory_reserved()

    # convert to GB for printing
    gb = 1024**3
    new_alloced /= gb
    new_cached /= gb
    delta_alloced /= gb
    delta_cached /= gb
    max_alloced /= gb
    max_cached /= gb

    print(
        f'RANK={rank} MEMSTATS',
        msg,
        f'device={torch.cuda.current_device()} '
        f'current alloc={new_alloced:0.4f}GB (delta={delta_alloced:0.4f}GB max={max_alloced:0.4f}GB) '
        f'current cache={new_cached:0.4f}GB (delta={delta_cached:0.4f}GB max={max_cached:0.4f}GB)'
    )
def see_memory_usage(message, force=False):
    """Log ``message`` followed by CUDA memory statistics (rank 0 only).

    Reporting is disabled by default, so calls are no-ops unless ``force=True``
    is passed. This keeps existing call sites silent while allowing opt-in
    reporting.

    Args:
        message (str): text logged before the statistics.
        force (bool): actually emit the report.
    """
    if not force:
        return
    # Print message except when distributed but not rank 0
    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return

    # The ``*_cached`` APIs are deprecated aliases of ``*_reserved`` in newer PyTorch.
    memory_reserved = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    max_memory_reserved = (getattr(torch.cuda, 'max_memory_reserved', None)
                           or torch.cuda.max_memory_cached)
    gb = 1024 * 1024 * 1024

    logger.info(message)
    logger.info(f"MA {round(torch.cuda.memory_allocated() / gb, 2)} GB "
                f"Max_MA {round(torch.cuda.max_memory_allocated() / gb, 2)} GB "
                f"CA {round(memory_reserved() / gb, 2)} GB "
                f"Max_CA {round(max_memory_reserved() / gb, 2)} GB ")
def call_to_str(base, *args, **kwargs):
    """Render a call as source-like text, e.g. ``f(1, 'a', key=2)``.

    Args:
        base (str): name of the call
        args (tuple, optional): positional args to ``base``, shown via ``repr``
        kwargs (dict, optional): keyword args to ``base``, shown as ``key=repr(value)``

    Returns:
        str: A string representation of base(*args, **kwargs)
    """
    rendered = [repr(arg) for arg in args]
    rendered.extend(f'{key}={arg!r}' for key, arg in kwargs.items())
    return f"{base}({', '.join(rendered)})"