|
12 | 12 | import math |
13 | 13 |
|
14 | 14 | import torch |
15 | | -from packaging import version |
16 | 15 |
|
17 | 16 |
|
18 | 17 | def compute_amplitude(waveforms, lengths=None, amp_type="avg", scale="linear"): |
@@ -280,26 +279,10 @@ def convolve1d( |
280 | 279 | kernel = torch.cat((after_index, zeros, before_index), dim=-1) |
281 | 280 |
|
282 | 281 | # Multiply in frequency domain to convolve in time domain |
283 | | - if version.parse(torch.__version__) > version.parse("1.6.0"): |
284 | | - import torch.fft as fft |
| 282 | + import torch.fft as fft |
285 | 283 |
|
286 | | - result = fft.rfft(waveform) * fft.rfft(kernel) |
287 | | - convolved = fft.irfft(result, n=waveform.size(-1)) |
288 | | - else: |
289 | | - f_signal = torch.rfft(waveform, 1) |
290 | | - f_kernel = torch.rfft(kernel, 1) |
291 | | - sig_real, sig_imag = f_signal.unbind(-1) |
292 | | - ker_real, ker_imag = f_kernel.unbind(-1) |
293 | | - f_result = torch.stack( |
294 | | - [ |
295 | | - sig_real * ker_real - sig_imag * ker_imag, |
296 | | - sig_real * ker_imag + sig_imag * ker_real, |
297 | | - ], |
298 | | - dim=-1, |
299 | | - ) |
300 | | - convolved = torch.irfft( |
301 | | - f_result, 1, signal_sizes=[waveform.size(-1)] |
302 | | - ) |
| 284 | + result = fft.rfft(waveform) * fft.rfft(kernel) |
| 285 | + convolved = fft.irfft(result, n=waveform.size(-1)) |
303 | 286 |
|
304 | 287 | # Use the implementation given by torch, which should be efficient on GPU |
305 | 288 | else: |
|
0 commit comments