77import sys
88from typing import Optional
99from enum import Enum
10- from deepspeed .runtime .config_utils import get_scalar_param , DeepSpeedConfigModel
10+ from deepspeed .runtime .config_utils import get_scalar_param , pp_int , DeepSpeedConfigModel
1111from deepspeed .utils import logger
1212from .offload_config import DeepSpeedZeroOffloadParamConfig , DeepSpeedZeroOffloadOptimizerConfig , OffloadDeviceEnum
1313
@@ -67,6 +67,7 @@ def get_zero_config(param_dict):
6767
6868
6969class ZeroStageEnum (int , Enum ):
70+ """ Enum class for possible zero stages """
7071 disabled = 0
7172 optimizer_states = 1
7273 gradients = 2
@@ -75,21 +76,86 @@ class ZeroStageEnum(int, Enum):
7576
7677
7778class DeepSpeedZeroConfig (DeepSpeedConfigModel ):
78- stage : ZeroStageEnum = ZeroStageEnum .disabled
79+ """
80+ Sets parameters for ZeRO optimizations.
81+ """
82+
83+ stage : ZeroStageEnum = 0
84+ """
85+ Chooses different stages of ZeRO Optimizer. Stage 0, 1, 2, and 3 refer
86+ to disabled, optimizer state partitioning, and optimizer+gradient state
87+ partitioning, and optimizer+gradient+parameter partitioning, respectively.
88+ """
89+
7990 contiguous_gradients : bool = True
91+ """
92+ Copies the gradients to a contiguous buffer as they are produced. Avoids
93+ memory fragmentation during backward pass.
94+ """
95+
8096 reduce_scatter : bool = True
81- reduce_bucket_size : int = Field (5e8 , ge = 0 )
97+ """
98+ Uses reduce or reduce scatter instead of allreduce to average gradients
99+ """
100+
101+ reduce_bucket_size : int = Field (pp_int (5e8 ), ge = 0 )
102+ """
103+ Number of elements reduced/allreduced at a time. Limits the memory required
104+ for the allgather for large model sizes
105+ """
106+
82107 allgather_partitions : bool = True
83- allgather_bucket_size : int = Field (5e8 , ge = 0 )
84- overlap_comm : bool = None # None for dynamic default value
108+ """
109+ Chooses between allgather collective or a series of broadcast collectives
110+ to gather updated parameters from all the GPUs at the end of each step
111+ """
112+
113+ allgather_bucket_size : int = Field (pp_int (5e8 ), ge = 0 )
114+ """
115+ Number of elements allgathered at a time. Limits the memory required for
116+ the allgather for large model sizes
117+ """
118+
119+ overlap_comm : bool = None # None for dynamic default value (see validator `overlap_comm_valid` below)
120+ """
121+ Attempts to overlap the reduction of the gradients with backward computation
122+ """
123+
85124 load_from_fp32_weights : bool = True
125+ """
126+ Boolean indicating whether to initialize fp32 master weights from fp32
127+ copies in checkpoint (no precision loss) or from model's fp16 copies (with
128+ precision loss). This can be used to initialize optimizer state even when
129+ checkpoint is missing optimizer state.
130+ """
86131
87132 elastic_checkpoint : bool = False
133+ """
134+ Enable loading checkpoint that was saved by job with different GPU count.
135+ No longer supported.
136+ """
88137
89- # Offload Specific Parameters
90138 offload_param : Optional [DeepSpeedZeroOffloadParamConfig ] = None
139+ """
140+ Enable offloading of model parameters to CPU or NVMe. This frees up GPU
141+ memory for larger models or batch sizes. Valid only with stage 3. Expects a
142+ dictionary containing values for `DeepSpeedZeroOffloadParamConfig`_.
143+ """
144+
91145 offload_optimizer : Optional [DeepSpeedZeroOffloadOptimizerConfig ] = None
92- sub_group_size : int = Field (1e9 , ge = 0 )
146+ """
147+ Enable offloading of optimizer state to CPU or NVMe, and optimizer
148+ computation to CPU. This frees up GPU memory for larger models or batch
149+ sizes. Valid for ZeRO stage 1, 2, 3. Expects a dictionary containing values
150+ for `DeepSpeedZeroOffloadOptimizerConfig`_.
151+ """
152+
153+ sub_group_size : int = Field (pp_int (1e9 ), ge = 0 )
154+ """
155+ Tile size for parameter processing to fit massive models (with trillions of
156+ parameters). Used by ZeRO3-Offload and ZeRO-Infinity
157+ """
158+
93159 cpu_offload_param : bool = Field (
94160 None ,
95161 deprecated = True ,
@@ -98,12 +164,16 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
98164 lambda val : DeepSpeedZeroOffloadParamConfig (device = OffloadDeviceEnum .cpu )
99165 if val else None ),
100166 )
167+ """ Deprecated, please use ``offload_param`` """
168+
101169 cpu_offload_use_pin_memory : bool = Field (
102170 None ,
103171 deprecated = True ,
104172 new_param = "offload_param or offload_optimizer" ,
105173 set_new_param = False ,
106174 )
175+ """ Deprecated, please use ``offload_param`` or ``offload_optimizer`` """
176+
107177 cpu_offload : bool = Field (
108178 None ,
109179 deprecated = True ,
@@ -112,29 +182,90 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
112182 lambda val : DeepSpeedZeroOffloadOptimizerConfig (device = OffloadDeviceEnum .cpu )
113183 if val else None ),
114184 )
185+ """ Deprecated, please use ``offload_optimizer`` """
186+
187+ prefetch_bucket_size : int = Field (pp_int (5e7 ),
188+ ge = 0 ,
189+ alias = "stage3_prefetch_bucket_size" )
190+ """
191+ Maximum number of parameter elements to fetch ahead of use. Used by ZeRO3,
192+ ZeRO3-Offload, ZeRO-Infinity, and ZeRO-Inference.
193+ """
115194
116- # Stage3 Specific Parameters
117- prefetch_bucket_size : int = Field (5e7 , ge = 0 , alias = "stage3_prefetch_bucket_size" )
118- param_persistence_threshold : int = Field (1e5 ,
195+ param_persistence_threshold : int = Field (pp_int (1e5 ),
119196 ge = 0 ,
120197 alias = "stage3_param_persistence_threshold" )
121- model_persistence_threshold : int = Field (sys .maxsize ,
198+ """
199+ Do not partition parameters smaller than this threshold. Smaller values use
200+ less memory, but can greatly increase communication (especially
201+ latency-bound messages).
202+ """
203+
204+ model_persistence_threshold : int = Field (pp_int (sys .maxsize ,
205+ "sys.maxsize" ),
122206 ge = 0 ,
123207 alias = "stage3_model_persistence_threshold" )
124- max_live_parameters : int = Field (1e9 , ge = 0 , alias = "stage3_max_live_parameters" )
125- max_reuse_distance : int = Field (1e9 , ge = 0 , alias = "stage3_max_reuse_distance" )
208+ """
209+ Maximum number of parameter elements that can be persisted in GPU and not
210+ partitioned. This imposes an upper bound on the number of unpartitioned
211+ parameters resulting from param_persistence_threshold setting. Used by
212+ ZeRO3-Offload, ZeRO-Infinity and ZeRO-Inference.
213+ """
214+
215+ max_live_parameters : int = Field (pp_int (1e9 ),
216+ ge = 0 ,
217+ alias = "stage3_max_live_parameters" )
218+ """
219+ The maximum number of parameters resident per GPU before releasing. Smaller
220+ values use less memory, but perform more communication.
221+ """
222+
223+ max_reuse_distance : int = Field (pp_int (1e9 ), ge = 0 , alias = "stage3_max_reuse_distance" )
224+ """
225+ Do not release a parameter if it will be reused within this threshold of
226+ parameters. Smaller values use less memory, but perform more communication.
227+ """
228+
126229 gather_16bit_weights_on_model_save : bool = Field (
127230 False ,
128231 alias = "stage3_gather_16bit_weights_on_model_save" )
232+ """
233+ Consolidate the weights before saving the model by ``save_16bit_model()``.
234+ Since the weights are partitioned across GPUs, they aren’t part of
235+ ``state_dict``, so this function automatically gathers the weights when
236+ this option is enabled and then saves the fp16 model weights.
237+ """
238+
129239 stage3_gather_fp16_weights_on_model_save : bool = Field (
130240 False ,
131241 deprecated = True ,
132242 new_param = "gather_16bit_weights_on_model_save" )
243+ """ Deprecated, please use ``gather_16bit_weights_on_model_save`` """
133244
134245 ignore_unused_parameters : bool = True
246+ """
247+ Unused parameters in modules may be unexpected in static networks, but
248+ could be normal in dynamic networks. This controls whether or not training
249+ should terminate with an error message when unused parameters are detected.
250+ This is set to ``False`` by default, which means unused parameters are
251+ ignored and training continues. Now is just used in stage 2.
252+ """
253+
135254 legacy_stage1 : bool = False
255+ """
256+ For backward-compatibility enable old ZeRO stage 1 implementation. Use at
257+ your own risk, will be deprecated soon.
258+ """
259+
136260 round_robin_gradients : bool = False
261+ """
262+ Stage 1 and 2 optimization for CPU offloading that parallelizes gradient
263+ copying to CPU memory among ranks by fine-grained gradient partitioning.
264+ Performance benefit grows with gradient accumulation steps (more copying
265+ between optimizer steps) or GPU count (increased parallelism).
266+ """
137267
268+ # Validators
138269 @validator ("overlap_comm" )
139270 def overlap_comm_valid (cls , field_value , values ):
140271 if field_value is None :
0 commit comments