
Commit c7c2063 (parent: c814fca)
Commit message: revert

1 file changed: deepspeed/runtime/engine.py (5 additions, 16 deletions)

@@ -21,7 +21,7 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -397,9 +397,6 @@ def zero_gather_fp16_weights_on_model_save(self):
     def fp16_enabled(self):
         return self._config.fp16_enabled
 
-    def precision(self):
-        return self._config.precision
-
     def amp_enabled(self):
         return self._config.amp_enabled
 
@@ -572,18 +569,14 @@ def is_replicated(p):
 
         for p in self.module.parameters():
             if torch.is_tensor(p) and is_replicated(p):
-                if self.precision() == torch.bfloat16:
-                    p = p.float()
                 dist.broadcast(p,
                                self.broadcast_src_rank,
                                group=self.data_parallel_group)
-                if self.precision() == torch.bfloat16:
-                    p = p.bfloat16()
 
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.to(self.precision())
+            self.module.half()
 
         if not self.dont_change_device:
             self.module.to(self.device)
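
Note on the hunk above: the deleted lines upcast each replicated bfloat16 parameter to float32 before dist.broadcast and cast it back afterwards, and _configure_distributed_model cast the module to the configured precision() instead of unconditionally calling half() (which always produces torch.float16). A minimal sketch of the intent of the reverted broadcast pattern, using a hypothetical standalone helper (broadcast_param is not part of DeepSpeed, and the float32 round-trip is assumed to be a workaround for backends without bfloat16 collectives):

import torch
import torch.distributed as dist

def broadcast_param(p, src_rank, group):
    # Hypothetical helper mirroring the intent of the removed lines: bfloat16
    # parameters are upcast to float32 so the broadcast happens in a widely
    # supported dtype, then the result is copied back into the parameter.
    if p.dtype == torch.bfloat16:
        buf = p.float()
        dist.broadcast(buf, src_rank, group=group)
        p.data.copy_(buf.bfloat16())
    else:
        dist.broadcast(p, src_rank, group=group)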
@@ -721,8 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if isinstance(optimizer,
-                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
+        if is_fp16_fused_supported_optimizer(optimizer):
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
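
The restored condition delegates to is_fp16_fused_supported_optimizer (re-imported in the first hunk) instead of hard-coding FusedAdam and the 1-bit Adam optimizer name. Its implementation is not shown in this diff; a plausible sketch, assuming it only checks the optimizer's type against the FP16_FUSED_SUPPORTED_OPTIMIZERS collection imported alongside it:

from deepspeed.ops.adam import FusedAdam  # assumed member of the supported set

# Assumed contents; the real list in deepspeed/runtime/fp16/fused_optimizer.py
# may include other classes (e.g. the 1-bit Adam optimizer).
FP16_FUSED_SUPPORTED_OPTIMIZERS = [FusedAdam]

def is_fp16_fused_supported_optimizer(optimizer):
    # True when the wrapped optimizer is one of the fused-FP16-capable types.
    return any(isinstance(optimizer, supported)
               for supported in FP16_FUSED_SUPPORTED_OPTIMIZERS)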
@@ -780,8 +772,7 @@ def _configure_zero_optimizer(self, optimizer):
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu,
-                precision=self.precision())
+                mpu=self.mpu)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -800,8 +791,7 @@ def _configure_zero_optimizer(self, optimizer):
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                precision=self.precision())
+                gradient_accumulation_steps=self.gradient_accumulation_steps())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -989,7 +979,6 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
 
         # Communicate only at gradient accumulation boundaries
         elif self.is_gradient_accumulation_boundary():
-            # TODO: communication in fp16 / fp32
             if self.zero_optimization_stage(
             ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
                 self.optimizer.reduce_scatter_gradients(
