3434
3535import speechbrain as sb
3636import speechbrain .nnet .schedulers as schedulers
37- from speechbrain .core import AMPConfig
3837from speechbrain .utils .distributed import run_on_main
3938from speechbrain .utils .logger import get_logger
4039
@@ -112,8 +111,6 @@ def compute_objectives(self, predictions, targets):
112111
113112 def fit_batch (self , batch ):
114113 """Trains one batch"""
115- amp = AMPConfig .from_name (self .precision )
116- should_step = (self .step % self .grad_accumulation_factor ) == 0
117114
118115 # Unpacking batch list
119116 mixture = batch .mix_sig
@@ -126,78 +123,39 @@ def fit_batch(self, batch):
126123 if self .hparams .num_spks == 3 :
127124 targets .append (batch .s3_sig )
128125
129- with self .no_sync (not should_step ):
130- if self .use_amp :
131- with torch .autocast (
132- dtype = amp .dtype ,
133- device_type = torch .device (self .device ).type ,
134- ):
135- predictions , targets = self .compute_forward (
136- mixture , targets , sb .Stage .TRAIN , noise
137- )
138- loss = self .compute_objectives (predictions , targets )
139-
140- # hard threshold the easy dataitems
141- if self .hparams .threshold_byloss :
142- th = self .hparams .threshold
143- loss_to_keep = loss [loss > th ]
144- if loss_to_keep .nelement () > 0 :
145- loss = loss_to_keep .mean ()
146- else :
147- loss = loss .mean ()
148-
149- if (
150- loss < self .hparams .loss_upper_lim and loss .nelement () > 0
151- ): # the fix for computational problems
152- self .scaler .scale (loss ).backward ()
153- if self .hparams .clip_grad_norm >= 0 :
154- self .scaler .unscale_ (self .optimizer )
155- torch .nn .utils .clip_grad_norm_ (
156- self .modules .parameters (),
157- self .hparams .clip_grad_norm ,
158- )
159- self .scaler .step (self .optimizer )
160- self .scaler .update ()
161- else :
162- self .nonfinite_count += 1
163- logger .info (
164- "infinite loss or empty loss! it happened {} times so far - skipping this batch" .format (
165- self .nonfinite_count
166- )
167- )
168- loss .data = torch .tensor (0.0 ).to (self .device )
126+ with self .training_ctx :
127+ predictions , targets = self .compute_forward (
128+ mixture , targets , sb .Stage .TRAIN , noise
129+ )
130+ loss = self .compute_objectives (predictions , targets )
131+
132+ # hard threshold the easy dataitems
133+ if self .hparams .threshold_byloss :
134+ th = self .hparams .threshold
135+ loss_to_keep = loss [loss > th ]
136+ if loss_to_keep .nelement () > 0 :
137+ loss = loss_to_keep .mean ()
169138 else :
170- predictions , targets = self .compute_forward (
171- mixture , targets , sb .Stage .TRAIN , noise
139+ loss = loss .mean ()
140+
141+ if loss < self .hparams .loss_upper_lim and loss .nelement () > 0 :
142+ self .scaler .scale (loss ).backward ()
143+ if self .hparams .clip_grad_norm >= 0 :
144+ self .scaler .unscale_ (self .optimizer )
145+ torch .nn .utils .clip_grad_norm_ (
146+ self .modules .parameters (),
147+ self .hparams .clip_grad_norm ,
172148 )
173- loss = self .compute_objectives (predictions , targets )
174-
175- if self .hparams .threshold_byloss :
176- th = self .hparams .threshold
177- loss_to_keep = loss [loss > th ]
178- if loss_to_keep .nelement () > 0 :
179- loss = loss_to_keep .mean ()
180- else :
181- loss = loss .mean ()
182-
183- if (
184- loss < self .hparams .loss_upper_lim and loss .nelement () > 0
185- ): # the fix for computational problems
186- loss .backward ()
187- if self .hparams .clip_grad_norm >= 0 :
188- torch .nn .utils .clip_grad_norm_ (
189- self .modules .parameters (),
190- self .hparams .clip_grad_norm ,
191- )
192- self .optimizer .step ()
193- else :
194- self .nonfinite_count += 1
195- logger .info (
196- "infinite loss or empty loss! it happened {} times so far - skipping this batch" .format (
197- self .nonfinite_count
198- )
199- )
200- loss .data = torch .tensor (0.0 ).to (self .device )
149+ self .scaler .step (self .optimizer )
150+ self .scaler .update ()
151+ else :
152+ self .nonfinite_count += 1
153+ logger .info (
154+ "infinite loss or empty loss! it happened {} times so far - skipping this batch" .format (
155+ self .nonfinite_count
156+ )
157+ )
158+ loss .data = torch .tensor (0.0 ).to (self .device )
201159 self .optimizer .zero_grad ()
202160
203161 return loss .detach ().cpu ()
0 commit comments