Commit 05a32ca

fix check for apex FusedAdam

1 parent: cea5de1

1 file changed: 2 additions & 3 deletions

deepspeed/runtime/engine.py
@@ -21,7 +21,7 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -714,8 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if isinstance(optimizer,
-                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
+        if is_fp16_fused_supported_optimizer(optimizer):
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
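
Note: the hunks above only show the call site in engine.py. The helper is imported from deepspeed/runtime/fp16/fused_optimizer.py and its definition is not part of this commit. The following is a minimal sketch of what FP16_FUSED_SUPPORTED_OPTIMIZERS and is_fp16_fused_supported_optimizer could look like, assuming the fix replaces the isinstance check against apex's FusedAdam with a class-name lookup; the list contents and the implementation are assumptions, not the actual DeepSpeed source.

# Hypothetical sketch of deepspeed/runtime/fp16/fused_optimizer.py additions.
# Optimizer class names the fused FP16_Optimizer wrapper is assumed to handle;
# 'OnebitAdam' is assumed to be the class behind ONEBIT_ADAM_OPTIMIZER.
FP16_FUSED_SUPPORTED_OPTIMIZERS = ['FusedAdam', 'OnebitAdam']


def is_fp16_fused_supported_optimizer(optimizer):
    # Compare by class name rather than isinstance(), so the check does not
    # require importing a specific FusedAdam implementation (apex's or
    # DeepSpeed's own) in engine.py.
    return type(optimizer).__name__ in FP16_FUSED_SUPPORTED_OPTIMIZERS

With a helper like this, the call site in _configure_fp16_optimizer reduces to the single-line check in the second hunk, and supporting another fused optimizer would only require extending the list.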
