1616# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
1717#Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
1818
19- import torch
20-
2119INITIAL_LOSS_SCALE = 'init_scale'
2220SCALE_WINDOW = 'scale_window'
2321DELAYED_SHIFT = 'delayed_shift'
@@ -107,7 +105,8 @@ def __init__(self,
107105 scale_window = 1000 ,
108106 min_scale = 1 ,
109107 delayed_shift = 1 ,
110- consecutive_hysteresis = False ):
108+ consecutive_hysteresis = False ,
109+ raise_error_at_min_scale = True ):
111110 super (DynamicLossScaler , self ).__init__ (init_scale )
112111 self .cur_iter = 0
113112 self .last_overflow_iter = - 1
@@ -117,13 +116,13 @@ def __init__(self,
117116 self .delayed_shift = delayed_shift
118117 self .cur_hysteresis = delayed_shift
119118 self .consecutive_hysteresis = consecutive_hysteresis
119+ self .raise_error_at_min_scale = raise_error_at_min_scale
120120
121121 # `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
    """Report whether any parameter gradient is flagged as overflowed.

    Walks ``params`` in order (a list or generator of parameters, per the
    comment accompanying this method) and checks each parameter that
    actually has a gradient by passing ``p.grad.data`` to
    ``self._has_inf_or_nan``. Parameters whose ``grad`` is ``None`` are
    ignored.

    Returns:
        ``True`` as soon as one gradient is flagged (later parameters are
        not inspected), ``False`` if none are flagged or ``params`` is
        empty.
    """
    # any() short-circuits on the first flagged gradient, so a generator
    # argument is consumed only up to that point.
    return any(
        param.grad is not None and self._has_inf_or_nan(param.grad.data)
        for param in params
    )
128127
129128 # `x` is a torch.Tensor
@@ -152,6 +151,9 @@ def update_scale(self, overflow):
152151 if overflow :
153152 # self.cur_scale /= self.scale_factor
154153 if self .delayed_shift == 1 or self .cur_hysteresis == 1 :
154+ if (self .cur_scale == self .min_scale ) and self .raise_error_at_min_scale :
155+ raise Exception ("Current loss scale already at minimum - cannot decrease scale anymore. Exiting "
156+ "run." )
155157 self .cur_scale = max (self .cur_scale / self .scale_factor , self .min_scale )
156158 else :
157159 self .cur_hysteresis -= 1
0 commit comments