Commit cea5de1

fix check for apex fused adam
1 parent 0959801 commit cea5de1

1 file changed (+29, -0)

deepspeed/runtime/fp16/fused_optimizer.py

Lines changed: 29 additions & 0 deletions
@@ -13,6 +13,35 @@
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
 from deepspeed.utils import logger, log_dist
 
+from ...ops.adam import FusedAdam
+FP16_FUSED_SUPPORTED_OPTIMIZERS = [
+    FusedAdam,
+]
+
+# Add apex FusedAdam to supported list if apex is installed
+try:
+    import apex
+    FP16_FUSED_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
+except ImportError:
+    pass
+
+
+def is_fp16_fused_supported_optimizer(optimizer):
+    """Is an optimizer compatible with ``FP16_Optimizer``?
+    Args:
+        optimizer (torch.optim.Optimizer): Optimizer to query.
+    Returns:
+        bool: True if ``optimizer`` is compatible with ``FP16_Optimizer``.
+    """
+    from deepspeed.runtime.config import ONEBIT_ADAM_OPTIMIZER
+    if isinstance(optimizer, tuple(FP16_FUSED_SUPPORTED_OPTIMIZERS)):
+        return True
+    if optimizer.__class__.__name__.lower() == ONEBIT_ADAM_OPTIMIZER.lower():
+        return True
+    return False
+
+
+
 
 class FP16_Optimizer(object):
     """
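The added helper checks whether an optimizer can use the fused FP16 path: it accepts DeepSpeed's own FusedAdam, apex's FusedAdam when apex is importable, and OneBit Adam, which is matched by class name against ONEBIT_ADAM_OPTIMIZER (presumably so the check does not need to import the 1-bit Adam implementation). Below is a minimal usage sketch, assuming a DeepSpeed installation that contains this commit; the toy model and the plain torch.optim.Adam optimizer are illustrative and not part of the change.

# Minimal usage sketch (illustrative, not part of this commit).
# Assumes a DeepSpeed build that includes is_fp16_fused_supported_optimizer.
import torch

from deepspeed.runtime.fp16.fused_optimizer import is_fp16_fused_supported_optimizer

model = torch.nn.Linear(8, 8)

# A plain torch.optim.Adam is neither DeepSpeed's FusedAdam, apex's FusedAdam,
# nor OneBit Adam, so the check returns False and a caller would not route it
# through the fused FP16_Optimizer wrapper.
plain_adam = torch.optim.Adam(model.parameters(), lr=1e-3)
print(is_fp16_fused_supported_optimizer(plain_adam))  # False

# DeepSpeed's FusedAdam (or apex.optimizers.FusedAdam when apex is installed)
# is in FP16_FUSED_SUPPORTED_OPTIMIZERS, so the same check returns True.
# Instantiating FusedAdam needs the compiled CUDA extension, so it is only
# sketched here as comments:
# from deepspeed.ops.adam import FusedAdam
# fused = FusedAdam(model.parameters(), lr=1e-3)
# assert is_fp16_fused_supported_optimizer(fused)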