Skip to content

Commit d48d36c

Browse files
authored
support torch>=2.9.0 (#3032)
1 parent a601051 commit d48d36c

6 files changed

Lines changed: 13 additions & 36 deletions

File tree

docs/readthedocs-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44

55
-r ../requirements.txt
66
-r docs-requirements.txt
7-
torch==2.8.0
7+
torch==2.9.0

recipes/LibriSpeech/ASR/transformer/extract_ssl_feats.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def audio_pipeline(wav):
5858
def compute_feats(uid, sig):
5959
sig = sig.to(hparams["device"]).unsqueeze(0)
6060
length = torch.ones(1, device=hparams["device"])
61-
with torch.no_grad(), torch.cuda.amp.autocast(dtype=hparams["dtype"]):
61+
with torch.no_grad(), torch.amp.autocast(
62+
hparams["device"].type, dtype=hparams["dtype"]
63+
):
6264
feats = normalizer(sig, length)
6365
feats = ssl_encoder(feats, length)
6466
return feats.squeeze(0).cpu()

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ requests>=2.20.0
1010
scipy>=1.4.1
1111
sentencepiece>=0.1.91
1212
soundfile>=0.12.1
13-
torch>=2.1.0,<2.9
14-
torchaudio>=2.1.0,<2.9
13+
torch>=2.1.0
14+
torchaudio>=2.1.0
1515
tqdm>=4.42.0
1616
transformers>=4.30.0

speechbrain/processing/multi_mic.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@
8080
"""
8181

8282
import torch
83-
from packaging import version
8483

8584
import speechbrain.processing.decomposition as eig
8685

@@ -765,11 +764,8 @@ def _gcc_phat(XXs, eps=1e-20):
765764
# Returning in the temporal domain
766765
XXs_phat = XXs_phat.transpose(2, 3)
767766

768-
if version.parse(torch.__version__) >= version.parse("1.8.0"):
769-
XXs_phat = torch.complex(XXs_phat[..., 0], XXs_phat[..., 1])
770-
xxs = torch.fft.irfft(XXs_phat, n=n_samples)
771-
else:
772-
xxs = torch.irfft(XXs_phat, signal_ndim=1, signal_sizes=[n_samples])
767+
XXs_phat = torch.complex(XXs_phat[..., 0], XXs_phat[..., 1])
768+
xxs = torch.fft.irfft(XXs_phat, n=n_samples)
773769

774770
xxs = xxs[..., XXs_idx, :]
775771

speechbrain/processing/signal_processing.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import math
1313

1414
import torch
15-
from packaging import version
1615

1716

1817
def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"):
@@ -280,26 +279,10 @@ def convolve1d(
280279
kernel = torch.cat((after_index, zeros, before_index), dim=-1)
281280

282281
# Multiply in frequency domain to convolve in time domain
283-
if version.parse(torch.__version__) > version.parse("1.6.0"):
284-
import torch.fft as fft
282+
import torch.fft as fft
285283

286-
result = fft.rfft(waveform) * fft.rfft(kernel)
287-
convolved = fft.irfft(result, n=waveform.size(-1))
288-
else:
289-
f_signal = torch.rfft(waveform, 1)
290-
f_kernel = torch.rfft(kernel, 1)
291-
sig_real, sig_imag = f_signal.unbind(-1)
292-
ker_real, ker_imag = f_kernel.unbind(-1)
293-
f_result = torch.stack(
294-
[
295-
sig_real * ker_real - sig_imag * ker_imag,
296-
sig_real * ker_imag + sig_imag * ker_real,
297-
],
298-
dim=-1,
299-
)
300-
convolved = torch.irfft(
301-
f_result, 1, signal_sizes=[waveform.size(-1)]
302-
)
284+
result = fft.rfft(waveform) * fft.rfft(kernel)
285+
convolved = fft.irfft(result, n=waveform.size(-1))
303286

304287
# Use the implementation given by torch, which should be efficient on GPU
305288
else:

speechbrain/utils/checkpoints.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,8 @@ def torch_parameter_transfer(obj, path):
276276
torch.optim.Optimizer: torch_save,
277277
torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
278278
}
279-
if version.parse(torch.__version__) < version.parse("2.0.0"):
280-
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler._LRScheduler] = torch_recovery
281-
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler._LRScheduler] = torch_save
282-
else:
283-
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
284-
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
279+
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_recovery
280+
DEFAULT_SAVE_HOOKS[torch.optim.lr_scheduler.LRScheduler] = torch_save
285281

286282
if version.parse(torch.__version__) < version.parse("2.4.0"):
287283
DEFAULT_LOAD_HOOKS[torch.cuda.amp.grad_scaler.GradScaler] = torch_recovery

0 commit comments

Comments (0)