We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2cd968f, commit 20ea07a (full SHA for 20ea07a can be copied from the commit page)
1 file changed
Megatron-LM-v1.1.5-ZeRO3/pretrain_gpt2.py
@@ -34,10 +34,9 @@ def model_provider():
     """Build the model."""
 
     print_rank_0('building GPT2 model ...')
-    with deepspeed.zero.InitContext(data_parallel_group=mpu.get_data_parallel_group(),
-                                    zero_modules=True,
-                                    remote_device=get_args().remote_device,
-                                    enabled=get_args().zero_stage==3):
+    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
+                             remote_device=get_args().remote_device,
+                             enabled=get_args().zero_stage==3):
         model = GPT2Model(num_tokentypes=0, parallel_output=True)
 
     return model
0 commit comments