
Commit 599258f

Authored by samyam, jeffra, tjruwase, Shaden Smith (ShadenSmith)
ZeRO 3 Offload (deepspeedai#834)
* Squash stage3 v1 (deepspeedai#146) Co-authored-by: Samyam <samyamr@microsoft.com> Co-authored-by: Jeff Rasley <jerasley@microsoft.com> Co-authored-by: Samyam Rajbhandari <samyamr@microsoft.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com> Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com> Co-authored-by: eltonzheng <eltonz@microsoft.com> * Fix correctness bug (deepspeedai#147) * formatting fix (deepspeedai#150) * stage3 bugfix (API) update and simplified FP16 Z3 tests (deepspeedai#151) * fp16 Z3 API update and bugfix * revert debug change * ZeRO-3 detach and race condition bugfixes (deepspeedai#149) * trying out ZeRO-3 race condition fix * CUDA sync instead of stream * reduction stream sync * remove commented code * Fix optimizer state_dict KeyError (deepspeedai#148) Co-authored-by: Jeff Rasley <jerasley@microsoft.com> * fix for smaller SGS sizes, ensures each grad is backed by unique tensors (deepspeedai#152) * Simplifying the logic for getting averaged gradients (deepspeedai#153) * skip for now * Z3 Docs redux (deepspeedai#154) * removing some TODOs and commented code (deepspeedai#155) * New Z3 defaults (deepspeedai#156) Co-authored-by: Jeff Rasley <jerasley@microsoft.com> * formatting * megatron external params Co-authored-by: Jeff Rasley <jerasley@microsoft.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com> Co-authored-by: Shaden Smith <Shaden.Smith@microsoft.com> Co-authored-by: Shaden Smith <ShadenTSmith@gmail.com> Co-authored-by: eltonzheng <eltonz@microsoft.com>
1 parent: ba33e86


41 files changed: +5,747 −321 lines

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion

@@ -48,4 +48,4 @@ jobs:
       - name: Unit tests
         run: |
           if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
-          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose -x tests/unit/
+          TORCH_EXTENSIONS_DIR=./torch-extensions pytest --durations=0 --forked --verbose tests/unit/

deepspeed/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,8 @@
 from .utils import log_dist
 from .utils.distributed import init_distributed

+from .runtime import zero
+
 from .pipe import PipelineModule

 from .git_version_info import version, git_hash, git_branch
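With `deepspeed.zero` exported from the package root, ZeRO-3 offload is typically enabled through the DeepSpeed JSON config rather than code. A minimal sketch of such a config; the field names below follow the current DeepSpeed documentation and may differ from the flags this commit originally introduced (early versions used e.g. a `cpu_offload` boolean):

```json
{
  "train_batch_size": 8,
  "fp16": { "enabled": true },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu" },
    "offload_param": { "device": "cpu" }
  }
}
```

This config is passed to `deepspeed.initialize`, which then builds the engine (and, per the docstring below, the CPU optimizer) for you.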

deepspeed/launcher/runner.py

Lines changed: 1 addition & 1 deletion

@@ -304,7 +304,7 @@ def main(args=None):
     # encode world info as base64 to make it easier to pass via command line
     world_info_base64 = encode_world_info(active_resources)

-    multi_node_exec = len(active_resources) > 1
+    multi_node_exec = True  # len(active_resources) > 1

     if multi_node_exec and not shutil.which('pdsh'):
         raise RuntimeError("pdsh is not installed, unable to proceed")
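The hunk above passes the "world info" (a hostname-to-GPU-slots map) to worker processes as a base64 string so it survives the shell command line intact. A rough stand-in for how such a round trip works (hypothetical reimplementation, not DeepSpeed's actual `encode_world_info`):

```python
import base64
import json


def encode_world_info(resources):
    """Serialize a {hostname: [gpu slots]} map to JSON and base64-encode it
    so it can be passed safely as a single command-line argument."""
    payload = json.dumps(resources).encode("utf-8")
    return base64.urlsafe_b64encode(payload).decode("utf-8")


def decode_world_info(encoded):
    """Reverse of encode_world_info: base64-decode, then parse the JSON."""
    return json.loads(base64.urlsafe_b64decode(encoded.encode("utf-8")))


world = {"worker-0": [0, 1, 2, 3], "worker-1": [0, 1, 2, 3]}
encoded = encode_world_info(world)
assert decode_world_info(encoded) == world
```

Base64 avoids quoting and whitespace issues that raw JSON would hit when forwarded through launchers such as pdsh.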

deepspeed/ops/adam/cpu_adam.py

Lines changed: 60 additions & 36 deletions

@@ -10,41 +10,6 @@


 class DeepSpeedCPUAdam(torch.optim.Optimizer):
-    """Fast vectorized implementation of two variations of Adam optimizer on CPU:
-
-    - Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980);
-    - AdamW: FIXING WEIGHT DECAY REGULARIZATION IN ADAM (https://arxiv.org/abs/1711.05101v1)
-
-    DeepSpeed CPU Adam(W) provides between 5x to 7x speedu over torch.optim.adam(W).
-    In order to apply this optimizer, the model requires to have its master parameter (in FP32)
-    reside on the CPU memory.
-
-    To train on a hetrogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
-    the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory,
-    with minimal impact on training througput. DeepSpeedCPUAdam plays an important role to minimize
-    the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial
-    (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
-
-    For calling step function, there are two options available: (1) update optimizer's states and (2) update
-    optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second
-    option can bring 30% higher throughput than the doing the copy separately using option one.
-
-
-    Arguments:
-        model_params (iterable): iterable of parameters to optimize or dicts defining
-            parameter groups.
-        lr (float, optional): learning rate. (default: 1e-3)
-        betas (Tuple[float, float], optional): coefficients used for computing
-            running averages of gradient and its square. (default: (0.9, 0.999))
-        eps (float, optional): term added to the denominator to improve
-            numerical stability. (default: 1e-8)
-        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
-            algorithm from the paper `On the Convergence of Adam and Beyond`_
-            (default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
-        adamw_mode: select between Adam and AdamW implementations (default: AdamW)
-    """
-
     optimizer_id = 0

     def __init__(self,
@@ -57,6 +22,47 @@ def __init__(self,
                  weight_decay=0,
                  amsgrad=False,
                  adamw_mode=True):
+        """Fast vectorized implementation of two variations of the Adam optimizer on CPU:
+
+        * Adam: A Method for Stochastic Optimization (https://arxiv.org/abs/1412.6980);
+        * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101)
+
+        DeepSpeed CPU Adam(W) provides between 5x and 7x speedup over torch.optim.Adam(W).
+        To apply this optimizer, the model's master parameters (in FP32) must reside
+        in CPU memory.
+
+        To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers
+        the ZeRO-Offload technology, which efficiently offloads the optimizer states into CPU memory,
+        with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role in minimizing
+        the optimizer's latency overhead on the CPU. Please refer to the ZeRO-Offload tutorial
+        (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology.
+
+        When calling the step function, two options are available: (1) update the optimizer's states,
+        or (2) update the optimizer's states and copy the parameters back to the GPU at the same time.
+        We have seen that the second option can bring 30% higher throughput than doing the copy
+        separately via option one.
+
+        .. note::
+                We recommend using our `config
+                <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
+                to allow :meth:`deepspeed.initialize` to build this optimizer
+                for you.
+
+        Arguments:
+            model_params (iterable): iterable of parameters to optimize or dicts defining
+                parameter groups.
+            lr (float, optional): learning rate. (default: 1e-3)
+            betas (Tuple[float, float], optional): coefficients used for computing
+                running averages of the gradient and its square. (default: (0.9, 0.999))
+            eps (float, optional): term added to the denominator to improve
+                numerical stability. (default: 1e-8)
+            weight_decay (float, optional): weight decay (L2 penalty). (default: 0)
+            amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+                algorithm from the paper `On the Convergence of Adam and Beyond`_.
+                (default: False) NOT SUPPORTED in DeepSpeed CPUAdam!
+            adamw_mode: select between Adam and AdamW implementations. (default: AdamW)
+        """

         default_args = dict(lr=lr,
                             betas=betas,
@@ -86,6 +92,24 @@ def __setstate__(self, state):

     @torch.no_grad()
     def step(self, closure=None, fp16_param_groups=None):
+        """Update the model parameters.
+
+        .. note::
+            This method will be called internally by ZeRO-Offload. DeepSpeed
+            users should still use ``engine.step()`` as shown in the
+            `Getting Started
+            <https://www.deepspeed.ai/getting-started/#training>`_ guide.
+
+        Args:
+            closure (callable, optional): closure to compute the loss.
+                Defaults to ``None``.
+            fp16_param_groups: FP16 GPU parameters to update. Performing the
+                copy here reduces communication time. Defaults to ``None``.
+
+        Returns:
+            loss: if ``closure`` is provided. Otherwise ``None``.
+        """
+
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -100,7 +124,7 @@ def step(self, closure=None, fp16_param_groups=None):
                 state = self.state[p]
                 # State initialization
                 if len(state) == 0:
-                    print(f'group {group_id} param {param_id} = {p.numel()}')
+                    #print(f'group {group_id} param {param_id} = {p.numel()}')
                     state['step'] = 0
                     # gradient momentums
                     state['exp_avg'] = torch.zeros_like(p.data,
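The `adamw_mode` flag in the docstring above selects between classic Adam (L2 penalty folded into the gradient) and AdamW (decoupled weight decay). A scalar, pure-Python sketch of that textbook update, for illustration only; the real DeepSpeedCPUAdam is vectorized C++ operating on whole parameter tensors:

```python
import math


def adam_step(param, grad, exp_avg, exp_avg_sq, step,
              lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
              weight_decay=0.0, adamw_mode=True):
    """One scalar Adam/AdamW update; returns the new (param, m, v, step)."""
    beta1, beta2 = betas
    step += 1
    if weight_decay and not adamw_mode:
        grad = grad + weight_decay * param  # classic Adam: L2 term in the gradient
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad               # first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad  # second moment
    bias1 = 1 - beta1 ** step   # bias corrections for zero-initialized moments
    bias2 = 1 - beta2 ** step
    update = (exp_avg / bias1) / (math.sqrt(exp_avg_sq / bias2) + eps)
    if weight_decay and adamw_mode:
        param = param - lr * weight_decay * param  # AdamW: decoupled decay
    param = param - lr * update
    return param, exp_avg, exp_avg_sq, step


# First step with grad = 1.0: the bias-corrected update is ~1, so the
# parameter moves by ~lr.
p, m, v, s = adam_step(1.0, 1.0, 0.0, 0.0, 0)
assert abs(p - 0.999) < 1e-5
```

The two weight-decay placements give different trajectories, which is exactly why the flag exists; DeepSpeed's kernel applies the same math elementwise over the FP32 master parameters held in CPU memory.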
