1 file changed: +29 −0 lines changed
 from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
 from deepspeed.utils import logger, log_dist

+from ...ops.adam import FusedAdam
+FP16_FUSED_SUPPORTED_OPTIMIZERS = [
+    FusedAdam,
+]
+
+# Add apex FusedAdam to supported list if apex is installed
+try:
+    import apex
+    FP16_FUSED_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
+except ImportError:
+    pass
+
+
+def is_fp16_fused_supported_optimizer(optimizer):
+    """Is an optimizer compatible with ``FP16_Optimizer``?
+    Args:
+        optimizer (torch.optim.Optimizer): Optimizer to query.
+    Returns:
+        bool: True if ``optimizer`` is compatible with ``FP16_Optimizer``.
+    """
+    from deepspeed.runtime.config import ONEBIT_ADAM_OPTIMIZER
+    if isinstance(optimizer, tuple(FP16_FUSED_SUPPORTED_OPTIMIZERS)):
+        return True
+    if optimizer.__class__.__name__.lower() == ONEBIT_ADAM_OPTIMIZER.lower():
+        return True
+    return False
+
+
+

 class FP16_Optimizer(object):
     """
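For context, a minimal sketch of how the new helper might be called. The import path `deepspeed.runtime.fp16.fused_optimizer` is an assumption inferred from the relative `...ops.adam` import in the diff (the patched file sits under `deepspeed/runtime/fp16/`), and constructing `FusedAdam` assumes DeepSpeed's fused Adam CUDA op can be built or loaded on the local machine:

```python
import torch

from deepspeed.ops.adam import FusedAdam
# Assumed module path for the file patched above; adjust if the
# helper lands in a different module.
from deepspeed.runtime.fp16.fused_optimizer import is_fp16_fused_supported_optimizer

params = [torch.nn.Parameter(torch.zeros(4))]

fused = FusedAdam(params, lr=1e-3)         # listed in FP16_FUSED_SUPPORTED_OPTIMIZERS
plain = torch.optim.Adam(params, lr=1e-3)  # neither fused Adam nor 1-bit Adam

assert is_fp16_fused_supported_optimizer(fused)      # matched by the isinstance() branch
assert not is_fp16_fused_supported_optimizer(plain)  # no match, falls through to False
```

Two details of the implementation are worth noting. `isinstance` accepts a tuple of types but not a list, hence the `tuple(FP16_FUSED_SUPPORTED_OPTIMIZERS)` conversion. And 1-bit Adam is recognized by class name against `ONEBIT_ADAM_OPTIMIZER` rather than by `isinstance`, presumably so this module does not have to import the 1-bit Adam implementation, with its extra dependencies, just to answer the query.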