-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathbeamform_multimic.py
More file actions
50 lines (39 loc) · 1.27 KB
/
beamform_multimic.py
File metadata and controls
50 lines (39 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""Beamformer for multi-mic processing.
Authors
* Nauman Dawalatabad
"""
import torch
from speechbrain.processing.features import ISTFT, STFT
from speechbrain.processing.multi_mic import Covariance, DelaySum, GccPhat
class DelaySum_Beamformer(torch.nn.Module):
    """Generate beamformed signal from multi-mic data using DelaySum beamforming.

    Processing chain applied in ``forward``: STFT of the input signals,
    covariance estimation, TDOA estimation with GCC-PHAT, delay-and-sum
    combination in the frequency domain, then an ISTFT back to the time
    domain.

    Arguments
    ---------
    sampling_rate : int (default: 16000)
        Sampling rate of audio signals. Passed to both the STFT and the ISTFT.
    """

    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.fs = sampling_rate
        self.stft = STFT(sample_rate=self.fs)
        self.cov = Covariance()
        self.gccphat = GccPhat()
        self.delaysum = DelaySum()
        self.istft = ISTFT(sample_rate=self.fs)

    def forward(self, mics_signals, sig_length=None):
        """Returns beamformed signal using multi-mic data.

        Arguments
        ---------
        mics_signals : torch.Tensor
            Set of audio signals to be transformed. Presumably shaped
            [batch, time, channels] as expected by speechbrain's STFT
            (NOTE(review): confirm against the STFT documentation).
        sig_length : int, optional (default: None)
            Desired length, in samples, of the output signal; forwarded to
            the ISTFT. When None, the ISTFT derives the length from the
            number of frames, which may not match the input length. Pass
            ``mics_signals.shape[1]`` to keep input and output aligned.

        Returns
        -------
        sig : torch.Tensor
            Time-domain beamformed signal produced by the ISTFT.
        """
        # The whole chain runs without autograd tracking.
        with torch.no_grad():
            Xs = self.stft(mics_signals)
            XXs = self.cov(Xs)
            tdoas = self.gccphat(XXs)
            Ys_ds = self.delaysum(Xs, tdoas)
            sig = self.istft(Ys_ds, sig_length=sig_length)
        return sig