-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Expand file tree
/
Copy pathfeatures.py
More file actions
862 lines (753 loc) · 29.8 KB
/
features.py
File metadata and controls
862 lines (753 loc) · 29.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
"""Basic feature pipelines.
Authors
* Mirco Ravanelli 2020
* Peter Plantinga 2020
* Sarthak Yadav 2020
* Sylvain de Langen 2024
"""
from dataclasses import dataclass
from functools import partial
from typing import Optional
import torch
from speechbrain.nnet.CNN import GaborConv1d
from speechbrain.nnet.normalization import PCEN
from speechbrain.nnet.pooling import GaussianLowpassPooling
from speechbrain.processing.features import (
DCT,
STFT,
ContextWindow,
Deltas,
Filterbank,
spectral_magnitude,
)
from speechbrain.processing.vocal_features import (
PERIODIC_NEIGHBORS,
compute_autocorr_features,
compute_gne,
compute_periodic_features,
compute_spectral_features,
)
from speechbrain.utils.autocast import fwd_default_precision
from speechbrain.utils.filter_analysis import FilterProperties
class Fbank(torch.nn.Module):
    """Generate features for input to the speech pipeline.

    Arguments
    ---------
    deltas : bool (default: False)
        Whether or not to append derivatives and second derivatives
        to the features.
    context : bool (default: False)
        Whether or not to append forward and backward contexts to
        the features.
    requires_grad : bool (default: False)
        Whether to allow parameters (i.e. fbank centers and
        spreads) to update during training.
    sample_rate : int (default: 16000)
        Sampling rate for the input waveforms.
    f_min : int (default: 0)
        Lowest frequency for the Mel filters.
    f_max : int (default: None)
        Highest frequency for the Mel filters. Note that if f_max is not
        specified it will be set to sample_rate // 2.
    n_fft : int (default: 400)
        Number of samples to use in each stft.
    n_mels : int (default: 40)
        Number of Mel filters.
    filter_shape : str (default: triangular)
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
    param_change_factor : float (default: 1.0)
        If freeze=False, this parameter affects the speed at which the filter
        parameters (i.e., central_freqs and bands) can be changed. When high
        (e.g., param_change_factor=1) the filters change a lot during training.
        When low (e.g. param_change_factor=0.1) the filter parameters are more
        stable during training.
    param_rand_factor : float (default: 0.0)
        This parameter can be used to randomly change the filter parameters
        (i.e, central frequencies and bands) during training. It is thus a
        sort of regularization. param_rand_factor=0 does not affect, while
        param_rand_factor=0.15 allows random variations within +-15% of the
        standard values of the filter parameters (e.g., if the central freq
        is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).
    left_frames : int (default: 5)
        Number of frames of left context to add.
    right_frames : int (default: 5)
        Number of frames of right context to add.
    win_length : float (default: 25)
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float (default: 10)
        Length (in ms) of the hop of the sliding window used to compute
        the STFT.

    Example
    -------
    >>> import torch
    >>> inputs = torch.randn([10, 16000])
    >>> feature_maker = Fbank()
    >>> feats = feature_maker(inputs)
    >>> feats.shape
    torch.Size([10, 101, 40])
    """

    def __init__(
        self,
        deltas=False,
        context=False,
        requires_grad=False,
        sample_rate=16000,
        f_min=0,
        f_max=None,
        n_fft=400,
        n_mels=40,
        filter_shape="triangular",
        param_change_factor=1.0,
        param_rand_factor=0.0,
        left_frames=5,
        right_frames=5,
        win_length=25,
        hop_length=10,
    ):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        # Default the upper Mel edge to the Nyquist frequency.
        if f_max is None:
            f_max = sample_rate // 2

        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            # Filterbank parameters are only trainable when requested.
            freeze=not requires_grad,
            filter_shape=filter_shape,
            param_change_factor=param_change_factor,
            param_rand_factor=param_rand_factor,
        )
        self.compute_deltas = Deltas(input_size=n_mels)
        self.context_window = ContextWindow(
            left_frames=left_frames,
            right_frames=right_frames,
        )

    @fwd_default_precision(cast_inputs=torch.float32)
    def forward(self, wav):
        """Returns a set of features generated from the input waveforms.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.

        Returns
        -------
        fbanks : torch.Tensor
            Filterbank features. When `deltas` is enabled, first and second
            derivatives are concatenated along dim 2; when `context` is
            enabled, the context window is applied afterwards.
        """
        # Named `stft` (not `STFT`) to avoid shadowing the imported class.
        stft = self.compute_STFT(wav)
        mag = spectral_magnitude(stft)
        fbanks = self.compute_fbanks(mag)

        if self.deltas:
            delta1 = self.compute_deltas(fbanks)
            delta2 = self.compute_deltas(delta1)
            fbanks = torch.cat([fbanks, delta1, delta2], dim=2)

        if self.context:
            fbanks = self.context_window(fbanks)

        return fbanks

    def get_filter_properties(self) -> FilterProperties:
        """Returns the filter properties reported by the STFT stage."""
        # only the STFT affects the FilterProperties of the Fbank
        return self.compute_STFT.get_filter_properties()
class MFCC(torch.nn.Module):
    """Generate features for input to the speech pipeline.

    Arguments
    ---------
    deltas : bool (default: True)
        Whether or not to append derivatives and second derivatives
        to the features.
    context : bool (default: True)
        Whether or not to append forward and backward contexts to
        the features.
    requires_grad : bool (default: False)
        Whether to allow parameters (i.e. fbank centers and
        spreads) to update during training.
    sample_rate : int (default: 16000)
        Sampling rate for the input waveforms.
    f_min : int (default: 0)
        Lowest frequency for the Mel filters.
    f_max : int (default: None)
        Highest frequency for the Mel filters. Note that if f_max is not
        specified it will be set to sample_rate // 2.
    n_fft : int (default: 400)
        Number of samples to use in each stft.
    n_mels : int (default: 23)
        Number of filters to use for creating filterbank.
    n_mfcc : int (default: 20)
        Number of output coefficients
    filter_shape : str (default 'triangular')
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
    param_change_factor : float (default 1.0)
        If freeze=False, this parameter affects the speed at which the filter
        parameters (i.e., central_freqs and bands) can be changed. When high
        (e.g., param_change_factor=1) the filters change a lot during training.
        When low (e.g. param_change_factor=0.1) the filter parameters are more
        stable during training.
    param_rand_factor : float (default 0.0)
        This parameter can be used to randomly change the filter parameters
        (i.e, central frequencies and bands) during training. It is thus a
        sort of regularization. param_rand_factor=0 does not affect, while
        param_rand_factor=0.15 allows random variations within +-15% of the
        standard values of the filter parameters (e.g., if the central freq
        is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).
    left_frames : int (default 5)
        Number of frames of left context to add.
    right_frames : int (default 5)
        Number of frames of right context to add.
    win_length : float (default: 25)
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float (default: 10)
        Length (in ms) of the hop of the sliding window used to compute
        the STFT.

    Example
    -------
    >>> import torch
    >>> inputs = torch.randn([10, 16000])
    >>> feature_maker = MFCC()
    >>> feats = feature_maker(inputs)
    >>> feats.shape
    torch.Size([10, 101, 660])
    """

    def __init__(
        self,
        deltas=True,
        context=True,
        requires_grad=False,
        sample_rate=16000,
        f_min=0,
        f_max=None,
        n_fft=400,
        n_mels=23,
        n_mfcc=20,
        filter_shape="triangular",
        param_change_factor=1.0,
        param_rand_factor=0.0,
        left_frames=5,
        right_frames=5,
        win_length=25,
        hop_length=10,
    ):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        # Default the upper Mel edge to the Nyquist frequency.
        if f_max is None:
            f_max = sample_rate // 2

        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            f_min=f_min,
            f_max=f_max,
            # Filterbank parameters are only trainable when requested.
            freeze=not requires_grad,
            filter_shape=filter_shape,
            param_change_factor=param_change_factor,
            param_rand_factor=param_rand_factor,
        )
        # The DCT maps n_mels filterbank energies to n_mfcc coefficients, so
        # the deltas operate on n_mfcc-sized features.
        self.compute_dct = DCT(input_size=n_mels, n_out=n_mfcc)
        self.compute_deltas = Deltas(input_size=n_mfcc)
        self.context_window = ContextWindow(
            left_frames=left_frames,
            right_frames=right_frames,
        )

    @fwd_default_precision(cast_inputs=torch.float32)
    def forward(self, wav):
        """Returns a set of mfccs generated from the input waveforms.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.

        Returns
        -------
        mfccs : torch.Tensor
            MFCC features. When `deltas` is enabled, first and second
            derivatives are concatenated along dim 2; when `context` is
            enabled, the context window is applied afterwards.
        """
        # Named `stft` (not `STFT`) to avoid shadowing the imported class.
        stft = self.compute_STFT(wav)
        mag = spectral_magnitude(stft)
        fbanks = self.compute_fbanks(mag)
        mfccs = self.compute_dct(fbanks)

        if self.deltas:
            delta1 = self.compute_deltas(mfccs)
            delta2 = self.compute_deltas(delta1)
            mfccs = torch.cat([mfccs, delta1, delta2], dim=2)

        if self.context:
            mfccs = self.context_window(mfccs)

        return mfccs
class Leaf(torch.nn.Module):
    """
    This class implements the LEAF audio frontend from

    Neil Zeghidour, Olivier Teboul, F{\'e}lix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A LEARNABLE FRONTEND
    FOR AUDIO CLASSIFICATION", in Proc. of ICLR 2021 (https://arxiv.org/abs/2101.08596)

    Arguments
    ---------
    out_channels : int
        It is the number of output channels.
    window_len : float
        length of filter window in milliseconds
    window_stride : float
        Stride factor of the filters in milliseconds
    sample_rate : int
        Sampling rate of the input signals. It is used to convert
        `window_len` and `window_stride` from milliseconds to samples, and
        it is also passed to the Gabor filterbank.
    input_shape : tuple
        Expected shape of the inputs. Only used to infer `in_channels` when
        `in_channels` is not given.
    in_channels : int
        Expected number of input channels.
    min_freq : float
        Lowest possible frequency (in Hz) for a filter
    max_freq : float
        Highest possible frequency (in Hz) for a filter. Passed as-is to
        the Gabor filterbank (may be None).
    use_pcen : bool
        If True (default), a per-channel energy normalization layer is used
    learnable_pcen : bool
        If True (default), the per-channel energy normalization layer is learnable
    use_legacy_complex : bool
        If False, torch.complex64 data type is used for gabor impulse responses
        If True, computation is performed on two real-valued torch.Tensors
    skip_transpose : bool
        If False, uses batch x time x channel convention of speechbrain.
        If True, uses batch x channel x time convention.
    n_fft : int
        Number of FFT bins

    Example
    -------
    >>> inp_tensor = torch.rand([10, 8000])
    >>> leaf = Leaf(
    ...     out_channels=40, window_len=25.0, window_stride=10.0, in_channels=1
    ... )
    >>> out_tensor = leaf(inp_tensor)
    >>> out_tensor.shape
    torch.Size([10, 50, 40])
    """

    def __init__(
        self,
        out_channels,
        window_len: float = 25.0,
        window_stride: float = 10.0,
        sample_rate: int = 16000,
        input_shape=None,
        in_channels=None,
        min_freq=60.0,
        max_freq=None,
        use_pcen=True,
        learnable_pcen=True,
        use_legacy_complex=False,
        skip_transpose=False,
        n_fft=512,
    ):
        super().__init__()
        self.out_channels = out_channels

        # Convert the millisecond window/stride into sample counts.
        window_size = int(sample_rate * window_len // 1000 + 1)
        window_stride = int(sample_rate * window_stride // 1000)

        if input_shape is None and in_channels is None:
            raise ValueError("Must provide one of input_shape or in_channels")
        if in_channels is None:
            in_channels = self._check_input_shape(input_shape)

        # Twice the channels: each complex filter is produced as a pair of
        # channels that `_squared_modulus_activation` later merges.
        self.complex_conv = GaborConv1d(
            out_channels=2 * out_channels,
            in_channels=in_channels,
            kernel_size=window_size,
            stride=1,
            padding="same",
            bias=False,
            n_fft=n_fft,
            sample_rate=sample_rate,
            min_freq=min_freq,
            max_freq=max_freq,
            use_legacy_complex=use_legacy_complex,
            skip_transpose=True,
        )

        self.pooling = GaussianLowpassPooling(
            in_channels=self.out_channels,
            kernel_size=window_size,
            stride=window_stride,
            skip_transpose=True,
        )
        if use_pcen:
            self.compression = PCEN(
                self.out_channels,
                alpha=0.96,
                smooth_coef=0.04,
                delta=2.0,
                floor=1e-12,
                trainable=learnable_pcen,
                per_channel_smooth_coef=True,
                skip_transpose=True,
            )
        else:
            self.compression = None
        self.skip_transpose = skip_transpose

    @fwd_default_precision(cast_inputs=torch.float32)
    def forward(self, x):
        """
        Returns the learned LEAF features

        Arguments
        ---------
        x : torch.Tensor of shape (batch, time, 1) or (batch, time)
            batch of input signals. 2d or 3d tensors are expected.

        Returns
        -------
        outputs : torch.Tensor
        """
        if not self.skip_transpose:
            x = x.transpose(1, -1)

        # 2d inputs get an explicit channel dimension.
        unsqueeze = x.ndim == 2
        if unsqueeze:
            x = x.unsqueeze(1)

        outputs = self.complex_conv(x)
        outputs = self._squared_modulus_activation(outputs)
        outputs = self.pooling(outputs)
        # Floor the pooled energies before compression.
        outputs = torch.maximum(
            outputs, torch.tensor(1e-5, device=outputs.device)
        )
        if self.compression is not None:
            outputs = self.compression(outputs)
        if not self.skip_transpose:
            outputs = outputs.transpose(1, -1)
        return outputs

    def _squared_modulus_activation(self, x):
        """Merges adjacent channel pairs into one channel.

        Each output channel is 2 * mean(a**2, b**2) = a**2 + b**2, where a
        and b are two neighbouring input channels. This halves the channel
        count (dim 1).
        """
        x = x.transpose(1, 2)
        output = 2 * torch.nn.functional.avg_pool1d(
            x**2.0, kernel_size=2, stride=2
        )
        output = output.transpose(1, 2)
        return output

    def _check_input_shape(self, shape):
        """Checks the input shape and returns the number of input channels."""
        if len(shape) == 2:
            in_channels = 1
        elif len(shape) == 3:
            in_channels = 1
        else:
            raise ValueError(
                "Leaf expects 2d or 3d inputs. Got " + str(len(shape))
            )
        return in_channels
def upalign_value(x, to: int) -> int:
    """Rounds `x` up to the nearest multiple of `to`.

    If `x` is already a multiple of `to`, it is returned unchanged.

    Arguments
    ---------
    x : int
        Non-negative value to align.
    to : int
        Alignment step; expected to be positive.

    Returns
    -------
    int
        The smallest multiple of `to` that is greater than or equal to `x`.

    Example
    -------
    >>> upalign_value(5, 4)
    8
    >>> upalign_value(8, 4)
    8
    """
    assert x >= 0

    remainder = x % to
    if remainder == 0:
        return x
    return x + to - remainder
@dataclass
class StreamingFeatureWrapperContext:
    """Streaming metadata for the feature extractor, caching past context
    frames between successive chunks.

    Attributes
    ----------
    left_context : Optional[torch.Tensor]
        Frames kept from the end of the previous chunk; they are inserted as
        left padding in front of the next chunk. Starts out as `None` and is
        refreshed from the tail of each processed chunk. See the relevant
        `forward` function for details.
    """

    # Cached left padding for the next chunk (None before the first chunk).
    left_context: Optional[torch.Tensor]
class StreamingFeatureWrapper(torch.nn.Module):
    """Wraps an arbitrary filter so that it can be used in a streaming fashion
    (i.e. on a per-chunk basis), by remembering context and making "clever" use
    of padding.

    Arguments
    ---------
    module : torch.nn.Module
        The filter to wrap; e.g. a module list that constitutes a sequential
        feature extraction pipeline.
        The module is assumed to pad its inputs, e.g. the output of a
        convolution with a stride of 1 would end up with the same frame count
        as the input.
    properties : FilterProperties
        The effective filter properties of the provided module. This is used to
        determine padding and caching.
    """

    def __init__(self, module: torch.nn.Module, properties: FilterProperties):
        super().__init__()

        self.module = module
        self.properties = properties

        if self.properties.causal:
            raise ValueError(
                "Causal streaming feature wrapper is not yet supported"
            )

        if self.properties.dilation != 1:
            raise ValueError(
                "Dilation not yet supported in streaming feature wrapper"
            )

    def get_required_padding(self) -> int:
        """Computes the number of padding/context frames that need to be
        injected at the past and future of the input signal in the forward pass.
        """
        # Half the window, rounded up to a multiple of the stride so that the
        # padding maps to a whole number of output frames.
        return upalign_value(
            (self.properties.window_size - 1) // 2, self.properties.stride
        )

    def get_output_count_per_pad_frame(self) -> int:
        """Computes the exact number of produced frames (along the time
        dimension) per input pad frame."""
        return self.get_required_padding() // self.properties.stride

    def get_recommended_final_chunk_count(self, frames_per_chunk: int) -> int:
        """Get the recommended number of zero chunks to inject at the end of an
        input stream depending on the filter properties of the extractor.

        The number of injected chunks is chosen to ensure that the filter has
        output frames centered on the last input frames.
        See also :meth:`~StreamingFeatureWrapper.forward`.

        Arguments
        ---------
        frames_per_chunk : int
            The number of frames per chunk, i.e. the size of the time dimension
            passed to :meth:`~StreamingFeatureWrapper.forward`.

        Returns
        -------
        Recommended number of chunks.
        """
        return (
            upalign_value(self.get_required_padding(), frames_per_chunk)
            // frames_per_chunk
        )

    def forward(
        self,
        chunk: torch.Tensor,
        context: StreamingFeatureWrapperContext,
        *extra_args,
        **extra_kwargs,
    ) -> torch.Tensor:
        """Forward pass for the streaming feature wrapper.

        For the first chunk, 0-padding is inserted at the past of the input.
        For any chunk (including the first), some future frames get truncated
        and cached to be inserted as left context for the next chunk in time.

        For further explanations, see the comments in the code.

        Note that due to how the padding is implemented, you may want to call
        this with a chunk worth full of zeros (potentially more for filters with
        large windows) at the end of your input so that the final frames have a
        chance to get processed by the filter.
        See :meth:`~StreamingFeatureWrapper.get_recommended_final_chunk_count`.
        This is not really an issue when processing endless streams, but when
        processing files, it could otherwise result in truncated outputs.

        Arguments
        ---------
        chunk : torch.Tensor
            Chunk of input of shape [batch size, time]; typically a raw
            waveform. Normally, in a chunkwise streaming scenario,
            `time = (stride-1) * chunk_size` where `chunk_size` is the desired
            **output** frame count.
        context : StreamingFeatureWrapperContext
            Mutable streaming context object; should be reused for subsequent
            calls in the same streaming session.
        *extra_args : tuple
        **extra_kwargs : dict
            Args to be passed to the module.

        Returns
        -------
        torch.Tensor
            Processed chunk of shape [batch size, output frames]. This shape is
            equivalent to the shape of `module(chunk)`.
        """
        feat_pad_size = self.get_required_padding()
        num_outputs_per_pad = self.get_output_count_per_pad_frame()

        # consider two audio chunks of 6 samples (for the example), where
        # each sample is denoted by 1, 2, ..., 6
        # so chunk 1 is 123456 and chunk 2 is 123456
        if context.left_context is None:
            # for the first chunk we left pad the input by two padding's worth of zeros,
            # and truncate the right, so that we can pretend to have right padding and
            # still consume the same amount of samples every time
            #
            # our first processed chunk will look like:
            # 0000123456
            #         ^^ right padding (truncated)
            #   ^^^^^^ frames that some outputs are centered on
            # ^^ left padding (truncated)
            chunk = torch.nn.functional.pad(chunk, (feat_pad_size * 2, 0))
        else:
            # prepend left context
            #
            # for the second chunk onwards, given the above example:
            # 34 of the previous chunk becomes left padding
            # 56 of the previous chunk becomes the first frames of this chunk
            # thus on the second iteration (and onwards) it will look like:
            # 3456123456
            #         ^^ right padding (truncated)
            #   ^^^^^^ frames that some outputs are centered on
            # ^^ left padding (truncated)
            chunk = torch.cat((context.left_context, chunk), 1)

        # our chunk's right context will become the start of the "next processed chunk"
        # plus we need left padding for that one, so make it double.
        # An explicit start index is used instead of `-feat_pad_size * 2`:
        # with zero padding, `chunk[:, -0:]` would cache the *whole* chunk and
        # the context would keep growing on every call.
        context_start = chunk.size(1) - feat_pad_size * 2
        context.left_context = chunk[:, context_start:]

        feats = self.module(chunk, *extra_args, **extra_kwargs)

        # truncate left and right context; an explicit end index is used for
        # the same reason as above (`feats[:, 0:-0]` would be empty).
        num_frames = feats.size(1)
        feats = feats[
            :, num_outputs_per_pad : num_frames - num_outputs_per_pad, ...
        ]

        return feats

    def get_filter_properties(self) -> FilterProperties:
        """Returns the filter properties this wrapper was constructed with."""
        return self.properties

    def make_streaming_context(self) -> StreamingFeatureWrapperContext:
        """Creates a fresh streaming context (no cached left context yet),
        to be passed to every :meth:`forward` call of one stream."""
        return StreamingFeatureWrapperContext(None)
class VocalFeatures(torch.nn.Module):
    """Estimates the vocal characteristics of a signal in four categories of
    features:

     * Autocorrelation-based
     * Period-based (jitter/shimmer)
     * Spectrum-based
     * MFCCs

    Arguments
    ---------
    min_f0_Hz : int
        Lowest admissible fundamental frequency; bounding the search range
        reduces octave errors. Defaults to 80 Hz, the low end of typical
        human voice.
    max_f0_Hz : int
        Highest admissible fundamental frequency; bounding the search range
        reduces octave errors. Defaults to 300 Hz, the high end of typical
        human voice.
    step_size : float
        Hop between consecutive analysis windows, in seconds.
    window_size : float
        Length of each analysis window, in seconds. It must span at least
        `PERIODIC_NEIGHBORS` periods of the lowest allowed f0; this is
        checked in the constructor.
    sample_rate : int
        The number of samples in a second.
    log_scores : bool
        If True, jitter/shimmer/hnr/gne are mapped to a log (dB-like) scale,
        since their raw values tend to sit close to zero.
    eps : float
        Lower clamp applied before the log transform. The default of 1e-3
        caps the transformed values at 30 dB.
    sma_neighbors : int
        Number of frames in the moving average -- default 3. No smoothing
        is applied when this is 1 or less.
    n_mels : int (default: 23)
        Number of filters to use for creating filterbank.
    n_mfcc : int (default: 4)
        Number of output coefficients

    Example
    -------
    >>> audio = torch.rand(1, 16000)
    >>> feature_maker = VocalFeatures()
    >>> vocal_features = feature_maker(audio)
    >>> vocal_features.shape
    torch.Size([1, 96, 17])
    """

    def __init__(
        self,
        min_f0_Hz: int = 80,
        max_f0_Hz: int = 300,
        step_size: float = 0.01,
        window_size: float = 0.05,
        sample_rate: int = 16000,
        log_scores: bool = True,
        eps: float = 1e-3,
        sma_neighbors: int = 3,
        n_mels: int = 23,
        n_mfcc: int = 4,
    ):
        super().__init__()

        # Frame geometry, expressed in samples.
        self.step_samples = int(step_size * sample_rate)
        self.window_samples = int(window_size * sample_rate)

        # A low f0 means a long period, so the largest lag comes from the
        # smallest frequency and vice versa.
        self.max_lag = int(sample_rate / min_f0_Hz)
        self.min_lag = int(sample_rate / max_f0_Hz)

        self.sample_rate = sample_rate
        self.log_scores = log_scores
        self.eps = eps
        self.sma_neighbors = sma_neighbors

        longest_span = self.max_lag * PERIODIC_NEIGHBORS
        assert longest_span <= self.window_samples, (
            f"Need at least {PERIODIC_NEIGHBORS} periods in a window"
        )

        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=self.window_samples,
            n_mels=n_mels,
        )
        self.compute_dct = DCT(input_size=n_mels, n_out=n_mfcc)
        self.compute_gne = partial(
            compute_gne, frame_len=window_size, hop_len=step_size
        )

    def forward(self, audio: torch.Tensor):
        """Compute voice features.

        Arguments
        ---------
        audio : torch.Tensor
            The audio signal to be converted to voice features, of shape
            [batch, samples].

        Returns
        -------
        features : torch.Tensor
            A [batch, frame, 13+n_mfcc] tensor with the following features
            per-frame.

             * autocorr_f0: A per-frame estimate of the f0 in Hz.
             * autocorr_hnr: harmonicity-to-noise ratio for each frame.
             * periodic_jitter: Average deviation in period length.
             * periodic_shimmer: Average deviation in amplitude per period.
             * gne: The glottal-to-noise-excitation ratio.
             * spectral_centroid: "center-of-mass" for spectral frames.
             * spectral_spread: avg distance from centroid for spectral frames.
             * spectral_skew: asymmetry of spectrum about the centroid.
             * spectral_kurtosis: tailedness of spectrum.
             * spectral_entropy: The peakiness of the spectrum.
             * spectral_flatness: The ratio of geometric mean to arithmetic mean.
             * spectral_crest: The ratio of spectral maximum to arithmetic mean.
             * spectral_flux: The 2-normed diff between successive spectral values.
             * mfcc_{0-n_mfcc}: The mel cepstral coefficients.
        """
        assert audio.dim() == 2, (
            "Expected audio to be 2-dimensional, [batch, samples]"
        )

        # Slice the waveform into overlapping analysis windows.
        windows = audio.unfold(
            dimension=-1, size=self.window_samples, step=self.step_samples
        )
        num_windows = windows.size(1)

        # Frame-level autocorrelation yields harmonicity and the best lag.
        harmonicity, lags = compute_autocorr_features(
            windows, self.min_lag, self.max_lag
        )
        f0 = self.sample_rate / lags

        # The autocorrelation score measures harmonicity, so 1 - score is the
        # noise part. See "Harmonic to Noise Ratio Measurement - Selection of
        # Window and Length" by J. Fernandez, F. Teixeira, V. Guedes,
        # A. Junior, and J. P. Teixeira. The ratio is dominated by its
        # denominator, so the numerator is ignored here.
        hnr = 1 - harmonicity

        jitter, shimmer = compute_periodic_features(windows, lags)

        # Resampling inside gne can produce extra frames; drop any surplus.
        gne = self.compute_gne(audio, self.sample_rate)
        if gne.size(1) > num_windows:
            gne = gne[:, :num_windows]

        # These scores hover near 0, so a log scale separates them better.
        if self.log_scores:

            def _neg_db(values):
                return -10 * values.clamp(min=self.eps).log10()

            hnr = _neg_db(hnr)
            jitter = _neg_db(jitter)
            shimmer = _neg_db(shimmer)
            gne = _neg_db(1 - gne)

        # Magnitude spectrum of the Hann-windowed frames feeds the rest.
        taper = torch.hann_window(self.window_samples, device=windows.device)
        tapered = windows * taper.view(1, 1, -1)
        spectrum = torch.abs(torch.fft.rfft(tapered))
        spectral_features = compute_spectral_features(spectrum)
        mel_energies = self.compute_fbanks(spectrum)
        mfccs = self.compute_dct(mel_energies)

        # Stack scalar voice features, then append spectral and cepstral ones.
        voice_features = torch.stack((f0, hnr, jitter, shimmer, gne), dim=-1)
        features = torch.cat(
            (voice_features, spectral_features, mfccs), dim=-1
        )

        # Smooth over time with a moving average, as OpenSMILE does.
        if self.sma_neighbors <= 1:
            return features
        return moving_average(features, dim=1, n=self.sma_neighbors)
def moving_average(features, dim=1, n=3):
    """Computes moving average on a given dimension.

    The output has the same size as the input along `dim`. Near the edges,
    only the available (non-padded) points are averaged. For an even `n`,
    output index ``i`` averages input indices ``i - n // 2`` to
    ``i + n // 2 - 1``.

    Arguments
    ---------
    features : torch.Tensor
        The feature tensor to smooth out.
    dim : int
        The time dimension (for smoothing).
    n : int
        The number of points in the moving average.

    Returns
    -------
    smoothed_features : torch.Tensor
        The features after the moving average is applied.

    Example
    -------
    >>> feats = torch.tensor([[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])
    >>> moving_average(feats)
    tensor([[0.5000, 0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.5000]])
    """
    # avg_pool1d pools over the last dimension, so move time there.
    features = features.transpose(dim, -1)
    length = features.size(-1)

    pad = n // 2
    features = torch.nn.functional.avg_pool1d(
        features, kernel_size=n, padding=pad, stride=1, count_include_pad=False
    )

    # Output length is L + 2 * (n // 2) - n + 1, which is L for odd n but
    # L + 1 for even n. Trim so the time dimension is always preserved.
    features = features[..., :length]

    return features.transpose(dim, -1)