Skip to content

Commit a6748a4

Browse files
committed
If the loss scale would be reduced below the minimum scale, the loss scaler will raise an error
1 parent 24026e5 commit a6748a4

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

deepspeed/runtime/fp16/loss_scaler.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
1717
#Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
1818

19-
import torch
20-
2119
INITIAL_LOSS_SCALE = 'init_scale'
2220
SCALE_WINDOW = 'scale_window'
2321
DELAYED_SHIFT = 'delayed_shift'
@@ -107,7 +105,8 @@ def __init__(self,
107105
scale_window=1000,
108106
min_scale=1,
109107
delayed_shift=1,
110-
consecutive_hysteresis=False):
108+
consecutive_hysteresis=False,
109+
raise_error_at_min_scale=True):
111110
super(DynamicLossScaler, self).__init__(init_scale)
112111
self.cur_iter = 0
113112
self.last_overflow_iter = -1
@@ -117,13 +116,13 @@ def __init__(self,
117116
self.delayed_shift = delayed_shift
118117
self.cur_hysteresis = delayed_shift
119118
self.consecutive_hysteresis = consecutive_hysteresis
119+
self.raise_error_at_min_scale = raise_error_at_min_scale
120120

121121
# `params` is a list / generator of torch.Variable
122122
def has_overflow_serial(self, params):
123123
for p in params:
124124
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
125125
return True
126-
127126
return False
128127

129128
# `x` is a torch.Tensor
@@ -152,6 +151,9 @@ def update_scale(self, overflow):
152151
if overflow:
153152
# self.cur_scale /= self.scale_factor
154153
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
154+
if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
155+
raise Exception("Current loss scale already at minimum - cannot decrease scale anymore. Exiting "
156+
"run.")
155157
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
156158
else:
157159
self.cur_hysteresis -= 1

0 commit comments

Comments (0)