Skip to content

Commit 9d27982

Browse files
authored
Fix "use_amp" errors in recipes (#2872)
1 parent 9078e55 commit 9d27982

14 files changed

Lines changed: 249 additions & 592 deletions

File tree

.github/workflows/pre-commit.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ jobs:
1212
- uses: actions/checkout@v2
1313
- uses: actions/setup-python@v2
1414
with:
15-
python-version: '3.8'
16-
- uses: pre-commit/action@v2.0.3
15+
python-version: '3.12'
16+
- uses: pre-commit/action@v3.0.1

.github/workflows/pythonapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
run: |
3131
pip install uv
3232
uv pip install --system ctc-segmentation sacrebleu # ctc-segmentation is funky with uv due to their oldest-supported-numpy dependency
33-
uv pip install --system -r requirements.txt torch==2.6.0+cpu torchaudio==2.6.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu k2==1.24.4.dev20250307+cpu.torch2.6.0 --find-links https://k2-fsa.github.io/k2/cpu.html kaldilm==1.15.1 gensim==4.3.2 bitsandbytes==0.45.3 scikit-learn==1.6.1
33+
uv pip install --system -r requirements.txt torch==2.6.0+cpu torchaudio==2.6.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu k2==1.24.4.dev20250307+cpu.torch2.6.0 --find-links https://k2-fsa.github.io/k2/cpu.html gensim==4.3.2 bitsandbytes==0.45.3 scikit-learn==1.6.1
3434
uv pip install --system --editable . --no-deps # already installed pinned deps from requirements.txt, we're good
3535
- name: Install sox
3636
run: |

.github/workflows/verify-docs-gen.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
run: |
2525
pip install uv
2626
uv pip install --system sphinx>=7.4.1
27-
uv pip install --system -r requirements.txt -r docs/docs-requirements.txt torch==2.6.0+cpu torchaudio==2.6.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu kaldilm==1.15.1
27+
uv pip install --system -r requirements.txt -r docs/docs-requirements.txt torch==2.6.0+cpu torchaudio==2.6.0+cpu --extra-index-url https://download.pytorch.org/whl/cpu
2828
uv pip install --system --editable . --no-deps # already installed pinned deps from requirements.txt, we're good
2929
- name: Generate docs
3030
run: |

conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def pytest_generate_tests(metafunc):
1616
"speechbrain/lobes/models/flair",
1717
"speechbrain/lobes/models/spacy",
1818
"speechbrain/alignment/ctc_segmentation.py",
19+
"speechbrain/lm/arpa.py",
1920
]
2021
try:
2122
import numba # noqa: F401

recipes/Aishell1Mix/separation/train.py

Lines changed: 31 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
import speechbrain as sb
3636
import speechbrain.nnet.schedulers as schedulers
37-
from speechbrain.core import AMPConfig
3837
from speechbrain.utils.distributed import run_on_main
3938
from speechbrain.utils.logger import get_logger
4039

@@ -112,8 +111,6 @@ def compute_objectives(self, predictions, targets):
112111

113112
def fit_batch(self, batch):
114113
"""Trains one batch"""
115-
amp = AMPConfig.from_name(self.precision)
116-
should_step = (self.step % self.grad_accumulation_factor) == 0
117114

118115
# Unpacking batch list
119116
mixture = batch.mix_sig
@@ -126,78 +123,39 @@ def fit_batch(self, batch):
126123
if self.hparams.num_spks == 3:
127124
targets.append(batch.s3_sig)
128125

129-
with self.no_sync(not should_step):
130-
if self.use_amp:
131-
with torch.autocast(
132-
dtype=amp.dtype,
133-
device_type=torch.device(self.device).type,
134-
):
135-
predictions, targets = self.compute_forward(
136-
mixture, targets, sb.Stage.TRAIN, noise
137-
)
138-
loss = self.compute_objectives(predictions, targets)
139-
140-
# hard threshold the easy dataitems
141-
if self.hparams.threshold_byloss:
142-
th = self.hparams.threshold
143-
loss_to_keep = loss[loss > th]
144-
if loss_to_keep.nelement() > 0:
145-
loss = loss_to_keep.mean()
146-
else:
147-
loss = loss.mean()
148-
149-
if (
150-
loss < self.hparams.loss_upper_lim and loss.nelement() > 0
151-
): # the fix for computational problems
152-
self.scaler.scale(loss).backward()
153-
if self.hparams.clip_grad_norm >= 0:
154-
self.scaler.unscale_(self.optimizer)
155-
torch.nn.utils.clip_grad_norm_(
156-
self.modules.parameters(),
157-
self.hparams.clip_grad_norm,
158-
)
159-
self.scaler.step(self.optimizer)
160-
self.scaler.update()
161-
else:
162-
self.nonfinite_count += 1
163-
logger.info(
164-
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
165-
self.nonfinite_count
166-
)
167-
)
168-
loss.data = torch.tensor(0.0).to(self.device)
126+
with self.training_ctx:
127+
predictions, targets = self.compute_forward(
128+
mixture, targets, sb.Stage.TRAIN, noise
129+
)
130+
loss = self.compute_objectives(predictions, targets)
131+
132+
# hard threshold the easy dataitems
133+
if self.hparams.threshold_byloss:
134+
th = self.hparams.threshold
135+
loss_to_keep = loss[loss > th]
136+
if loss_to_keep.nelement() > 0:
137+
loss = loss_to_keep.mean()
169138
else:
170-
predictions, targets = self.compute_forward(
171-
mixture, targets, sb.Stage.TRAIN, noise
139+
loss = loss.mean()
140+
141+
if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
142+
self.scaler.scale(loss).backward()
143+
if self.hparams.clip_grad_norm >= 0:
144+
self.scaler.unscale_(self.optimizer)
145+
torch.nn.utils.clip_grad_norm_(
146+
self.modules.parameters(),
147+
self.hparams.clip_grad_norm,
172148
)
173-
loss = self.compute_objectives(predictions, targets)
174-
175-
if self.hparams.threshold_byloss:
176-
th = self.hparams.threshold
177-
loss_to_keep = loss[loss > th]
178-
if loss_to_keep.nelement() > 0:
179-
loss = loss_to_keep.mean()
180-
else:
181-
loss = loss.mean()
182-
183-
if (
184-
loss < self.hparams.loss_upper_lim and loss.nelement() > 0
185-
): # the fix for computational problems
186-
loss.backward()
187-
if self.hparams.clip_grad_norm >= 0:
188-
torch.nn.utils.clip_grad_norm_(
189-
self.modules.parameters(),
190-
self.hparams.clip_grad_norm,
191-
)
192-
self.optimizer.step()
193-
else:
194-
self.nonfinite_count += 1
195-
logger.info(
196-
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
197-
self.nonfinite_count
198-
)
199-
)
200-
loss.data = torch.tensor(0.0).to(self.device)
149+
self.scaler.step(self.optimizer)
150+
self.scaler.update()
151+
else:
152+
self.nonfinite_count += 1
153+
logger.info(
154+
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
155+
self.nonfinite_count
156+
)
157+
)
158+
loss.data = torch.tensor(0.0).to(self.device)
201159
self.optimizer.zero_grad()
202160

203161
return loss.detach().cpu()

recipes/BinauralWSJ0Mix/separation/train.py

Lines changed: 31 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
import speechbrain as sb
3838
import speechbrain.nnet.schedulers as schedulers
39-
from speechbrain.core import AMPConfig
4039
from speechbrain.processing.features import STFT, spectral_magnitude
4140
from speechbrain.utils.distributed import run_on_main
4241
from speechbrain.utils.logger import get_logger
@@ -200,8 +199,6 @@ def compute_objectives(self, predictions, targets):
200199

201200
def fit_batch(self, batch):
202201
"""Trains one batch"""
203-
amp = AMPConfig.from_name(self.precision)
204-
should_step = (self.step % self.grad_accumulation_factor) == 0
205202

206203
# Unpacking batch list
207204
mixture = batch.mix_sig
@@ -214,78 +211,39 @@ def fit_batch(self, batch):
214211
if "noise" in self.hparams.experiment_name:
215212
noise = batch.noise_sig[0]
216213

217-
with self.no_sync(not should_step):
218-
if self.use_amp:
219-
with torch.autocast(
220-
dtype=amp.dtype,
221-
device_type=torch.device(self.device).type,
222-
):
223-
predictions, targets = self.compute_forward(
224-
mixture, targets, sb.Stage.TRAIN, noise
225-
)
226-
loss = self.compute_objectives(predictions, targets)
227-
228-
# hard threshold the easy dataitems
229-
if self.hparams.threshold_byloss:
230-
th = self.hparams.threshold
231-
loss = loss[loss > th]
232-
if loss.nelement() > 0:
233-
loss = loss.mean()
234-
else:
235-
loss = loss.mean()
236-
237-
if (
238-
loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
239-
): # the fix for computational problems
240-
self.scaler.scale(loss).backward()
241-
if self.hparams.clip_grad_norm >= 0:
242-
self.scaler.unscale_(self.optimizer)
243-
torch.nn.utils.clip_grad_norm_(
244-
self.modules.parameters(),
245-
self.hparams.clip_grad_norm,
246-
)
247-
self.scaler.step(self.optimizer)
248-
self.scaler.update()
249-
else:
250-
self.nonfinite_count += 1
251-
logger.info(
252-
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
253-
self.nonfinite_count
254-
)
255-
)
256-
loss.data = torch.tensor(0.0).to(self.device)
257-
else:
258-
predictions, targets = self.compute_forward(
259-
mixture, targets, sb.Stage.TRAIN, noise
260-
)
261-
loss = self.compute_objectives(predictions, targets)
214+
with self.training_ctx:
215+
predictions, targets = self.compute_forward(
216+
mixture, targets, sb.Stage.TRAIN, noise
217+
)
218+
loss = self.compute_objectives(predictions, targets)
262219

263-
if self.hparams.threshold_byloss:
264-
th = self.hparams.threshold
265-
loss = loss[loss > th]
266-
if loss.nelement() > 0:
267-
loss = loss.mean()
268-
else:
220+
# hard threshold the easy dataitems
221+
if self.hparams.threshold_byloss:
222+
th = self.hparams.threshold
223+
loss = loss[loss > th]
224+
if loss.nelement() > 0:
269225
loss = loss.mean()
270-
271-
if (
272-
loss.nelement() > 0 and loss < self.hparams.loss_upper_lim
273-
): # the fix for computational problems
274-
loss.backward()
275-
if self.hparams.clip_grad_norm >= 0:
276-
torch.nn.utils.clip_grad_norm_(
277-
self.modules.parameters(),
278-
self.hparams.clip_grad_norm,
279-
)
280-
self.optimizer.step()
281-
else:
282-
self.nonfinite_count += 1
283-
logger.info(
284-
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
285-
self.nonfinite_count
286-
)
287-
)
288-
loss.data = torch.tensor(0.0).to(self.device)
226+
else:
227+
loss = loss.mean()
228+
229+
if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
230+
self.scaler.scale(loss).backward()
231+
if self.hparams.clip_grad_norm >= 0:
232+
self.scaler.unscale_(self.optimizer)
233+
torch.nn.utils.clip_grad_norm_(
234+
self.modules.parameters(),
235+
self.hparams.clip_grad_norm,
236+
)
237+
self.scaler.step(self.optimizer)
238+
self.scaler.update()
239+
else:
240+
self.nonfinite_count += 1
241+
logger.info(
242+
"infinite loss or empty loss! it happened {} times so far - skipping this batch".format(
243+
self.nonfinite_count
244+
)
245+
)
246+
loss.data = torch.tensor(0.0).to(self.device)
289247
self.optimizer.zero_grad()
290248

291249
return loss.detach().cpu()

0 commit comments

Comments
 (0)