forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
executable file
·84 lines (66 loc) · 2.86 KB
/
utils.py
File metadata and controls
executable file
·84 lines (66 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
from typing import List
import torch
from deepspeed import comm as dist
from deepspeed.utils import logger
from deepspeed.ops.adam import DeepSpeedCPUAdam
from deepspeed.ops.adam import FusedAdam
from deepspeed.utils.nvtx import instrument_w_nvtx
from deepspeed.accelerator import get_accelerator
def _initialize_parameter_parallel_groups(parameter_parallel_size=None):
    """Create the parameter-parallel process groups and return this rank's group.

    The world is split into consecutive chunks of ``parameter_parallel_size``
    ranks, and ``dist.new_group`` is called once per chunk.

    Args:
        parameter_parallel_size: number of ranks per group. A falsy value
            (``None`` or ``0``) selects the full world size, i.e. one group.

    Returns:
        The group handle returned by ``dist.new_group`` for the chunk that
        contains the calling rank, or ``None`` if no chunk contains it.

    Raises:
        AssertionError: if the world size is not a multiple of the group size.
    """
    world_size = int(dist.get_world_size())
    group_size = parameter_parallel_size or world_size
    logger.info("data_parallel_size: %s, parameter_parallel_size: %s", world_size, group_size)
    assert world_size % group_size == 0, \
        'world size should be divisible by parameter parallel size'

    this_rank = dist.get_rank()
    num_groups = world_size // group_size
    own_group = None
    for group_idx in range(num_groups):
        first_member = group_idx * group_size
        members = range(first_member, first_member + group_size)
        # new_group is invoked for every chunk, not only the one this rank
        # belongs to (presumably because group creation is collective --
        # confirm against the comm backend).
        candidate = dist.new_group(members)
        if this_rank in members:
            own_group = candidate
    return own_group
class ZeRORuntimeException(Exception):
    """Exception type named for ZeRO runtime errors; adds nothing to ``Exception``."""
# Optimizer classes listed as supported by ZeRO.
ZERO_SUPPORTED_OPTIMIZERS = [
    torch.optim.Adam,
    torch.optim.AdamW,
    FusedAdam,
    DeepSpeedCPUAdam,
]

# apex is an optional dependency: register its FusedAdam only when the package
# imports and actually exposes ``optimizers.FusedAdam``.
try:
    import apex
    if hasattr(apex, 'optimizers'):
        if hasattr(apex.optimizers, 'FusedAdam'):
            ZERO_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
except ImportError:
    # apex is not installed; keep the built-in list as is.
    pass
def is_zero_supported_optimizer(optimizer):
    """Return whether the exact type of ``optimizer`` is in ZERO_SUPPORTED_OPTIMIZERS.

    The check uses ``type(optimizer)``, so an instance of a subclass of a
    listed optimizer is reported as unsupported. Rank 0 additionally logs the
    optimizer's class name and type.

    Args:
        optimizer: the optimizer instance to check.

    Returns:
        ``True`` if ``type(optimizer)`` is a member of ZERO_SUPPORTED_OPTIMIZERS,
        ``False`` otherwise.
    """
    on_rank_zero = dist.get_rank() == 0
    if on_rank_zero:
        # Only one rank reports, so the message appears once per check.
        logger.info(f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}')
    optimizer_type = type(optimizer)
    return optimizer_type in ZERO_SUPPORTED_OPTIMIZERS
def get_lst_from_rank0(lst: List[int]) -> List[int]:
    """Broadcast rank 0's ``lst`` and return those values on the calling rank.

    NOTE: creates both communication and synchronization overhead so should be used
    sparingly

    Rank 0 sends its own values; every other rank allocates a placeholder of
    ``len(lst)`` entries filled with ``-1`` that the broadcast overwrites. The
    placeholder is sized from the caller's own list, so callers are expected
    to pass lists of equal length on every rank (NOTE(review): confirm how the
    comm backend behaves on mismatched shapes).

    Args:
        lst: this rank's list of ints. Only rank 0's values are transmitted.

    Returns:
        The values received from rank 0, as a list of built-in ``int``.
    """
    accelerator = get_accelerator()
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        # Place the tensor on the device that matches this process' LOCAL_RANK.
        device = torch.device(accelerator.device_name(local_rank))
    else:
        # LOCAL_RANK is not exported by every launcher; use the accelerator's
        # current device instead of failing with a KeyError.
        device = torch.device(accelerator.current_device_name())

    payload = lst if dist.get_rank() == 0 else [-1] * len(lst)
    lst_tensor = torch.tensor(
        payload,
        dtype=int,
        device=device,
        requires_grad=False,
    )
    # Blocking broadcast: the tensor holds rank 0's values once this returns.
    dist.broadcast(lst_tensor, src=0, async_op=False)
    # tolist() yields built-in ints rather than numpy scalar objects.
    return lst_tensor.cpu().tolist()
@instrument_w_nvtx
def assert_ints_same_as_other_ranks(ints: List[int]) -> None:
    """Raise if this rank's ``ints`` differ from the list held by rank 0.

    NOTE: creates both communication and synchronization overhead so should be
    used sparingly

    The reference list is fetched with ``get_lst_from_rank0`` and compared
    against the local list with ``!=``.

    Args:
        ints: this rank's list of ints.

    Raises:
        RuntimeError: if the local list is not equal to rank 0's list; the
            message contains both lists and the calling rank.
    """
    rank0_ints = get_lst_from_rank0(ints)
    mismatch = ints != rank0_ints
    if mismatch:
        message = (f"disagreement between rank0 and rank{dist.get_rank()}: "
                   f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}")
        raise RuntimeError(message)