Skip to content

Commit b18b278

Browse files
committed
Move main vocal feature extraction code to lobes
1 parent cb4b3b4 commit b18b278

3 files changed

Lines changed: 383 additions & 322 deletions

File tree

docs/tutorials/preprocessing/voice-analysis.ipynb

Lines changed: 77 additions & 76 deletions
Large diffs are not rendered by default.

speechbrain/lobes/features.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from dataclasses import dataclass
11+
from functools import partial
1112
from typing import Optional
1213

1314
import torch
@@ -23,9 +24,18 @@
2324
Filterbank,
2425
spectral_magnitude,
2526
)
27+
from speechbrain.processing.vocal_features import (
28+
PERIODIC_NEIGHBORS,
29+
compute_autocorr_features,
30+
compute_gne,
31+
compute_periodic_features,
32+
compute_spectral_features,
33+
)
2634
from speechbrain.utils.autocast import fwd_default_precision
2735
from speechbrain.utils.filter_analysis import FilterProperties
2836

37+
VOICE_EPSILON = 1e-3
38+
2939

3040
class Fbank(torch.nn.Module):
3141
"""Generate features for input to the speech pipeline.
@@ -660,3 +670,184 @@ def get_filter_properties(self) -> FilterProperties:
660670

661671
def make_streaming_context(self) -> StreamingFeatureWrapperContext:
662672
return StreamingFeatureWrapperContext(None)
673+
674+
675+
class VocalFeatures(torch.nn.Module):
    """Estimates the vocal characteristics of a signal in four categories of features:
    * Autocorrelation-based
    * Period-based (jitter/shimmer)
    * Spectrum-based
    * MFCCs

    Arguments
    ---------
    min_f0_Hz: int
        The minimum allowed fundamental frequency, to reduce octave errors.
        Default is 80 Hz, based on human voice standard frequency range.
    max_f0_Hz: int
        The maximum allowed fundamental frequency, to reduce octave errors.
        Default is 300 Hz, based on human voice standard frequency range.
    step_size: float
        The time between analysis windows (in seconds).
    window_size: float
        The size of the analysis window (in seconds). Must be long enough
        to contain at least 4 periods at the minimum frequency.
    sample_rate: int
        The number of samples in a second.
    log_scores: bool
        Whether to represent the jitter/shimmer/hnr on a log scale.
    eps: float
        The minimum value before log transformation, default of
        1e-3 results in a maximum value of 30 dB.
    sma_neighbors: int
        Number of frames to average -- default 3
    """

    def __init__(
        self,
        min_f0_Hz: int = 80,
        max_f0_Hz: int = 300,
        step_size: float = 0.01,
        window_size: float = 0.05,
        sample_rate: int = 16000,
        log_scores: bool = True,
        eps: float = 1e-3,
        sma_neighbors: int = 3,
    ):
        super().__init__()

        # Translate the second-based arguments into sample counts.
        # The longest lag corresponds to the lowest f0, and vice versa.
        self.step_samples = int(step_size * sample_rate)
        self.window_samples = int(window_size * sample_rate)
        self.max_lag = int(sample_rate / min_f0_Hz)
        self.min_lag = int(sample_rate / max_f0_Hz)
        self.sample_rate = sample_rate
        self.log_scores = log_scores
        self.eps = eps
        self.sma_neighbors = sma_neighbors

        # Period-based features need several full periods per analysis window.
        assert (
            self.max_lag * PERIODIC_NEIGHBORS <= self.window_samples
        ), f"Need at least {PERIODIC_NEIGHBORS} periods in a window"

        mel_count, mfcc_count = 23, 4
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=self.window_samples,
            n_mels=mel_count,
        )
        self.compute_dct = DCT(input_size=mel_count, n_out=mfcc_count)
        self.compute_gne = partial(
            compute_gne, frame_len=window_size, hop_len=step_size
        )

    def forward(self, audio: torch.Tensor):
        """Compute voice features.

        Arguments
        ---------
        audio: torch.Tensor
            The audio signal to be converted to voice features.

        Returns
        -------
        features: torch.Tensor
            A [batch, frame, 17] tensor with the following features per-frame.
            * autocorr_f0: A per-frame estimate of the f0 in Hz.
            * autocorr_hnr: harmonicity-to-noise ratio for each frame.
            * periodic_jitter: Average deviation in period length.
            * periodic_shimmer: Average deviation in amplitude per period.
            * gne: The glottal-to-noise-excitation ratio.
            * spectral_centroid: "center-of-mass" for spectral frames.
            * spectral_spread: avg distance from centroid for spectral frames.
            * spectral_skew: asymmetry of spectrum about the centroid.
            * spectral_kurtosis: tailedness of spectrum.
            * spectral_entropy: The peakiness of the spectrum.
            * spectral_flatness: The ratio of geometric mean to arithmetic mean.
            * spectral_crest: The ratio of spectral maximum to arithmetic mean.
            * spectral_flux: The 2-normed diff between successive spectral values.
            * mfcc_0: The first mel cepstral coefficient.
            * mfcc_1: The second mel cepstral coefficient.
            * mfcc_2: The third mel cepstral coefficient.
            * mfcc_3: The fourth mel cepstral coefficient.
        """
        assert (
            audio.dim() == 2
        ), "Expected audio to be 2-dimensional, [batch, samples]"

        # Slice the signal into overlapping analysis windows.
        windows = audio.unfold(
            dimension=-1, size=self.window_samples, step=self.step_samples
        )

        # Frame-wise autocorrelation gives a harmonicity score and the best
        # period lag, which converts directly to an f0 estimate in Hz.
        harmonicity, lags = compute_autocorr_features(
            windows, self.min_lag, self.max_lag
        )
        f0 = self.sample_rate / lags

        # The autocorrelation score is the source of harmonicity here, and
        # 1 - harmonicity is the noise. See "Harmonic to Noise Ratio
        # Measurement - Selection of Window and Length" by J. Fernandez,
        # F. Teixeira, V. Guedes, A. Junior, and J. P. Teixeira.
        # The ratio is dominated by the denominator; the numerator is ignored.
        hnr = 1 - harmonicity
        jitter, shimmer = compute_periodic_features(windows, lags)

        # GNE involves resampling, so its frame count may differ slightly;
        # trim any extra frames to align with the other features.
        gne = self.compute_gne(audio, self.sample_rate)
        frame_count = windows.size(1)
        if gne.size(1) > frame_count:
            gne = gne[:, :frame_count]

        # These scores sit near zero most of the time; a negated, clamped
        # log10 scale spreads them out to better differentiate frames.
        if self.log_scores:
            hnr, jitter, shimmer, gne = (
                -10 * score.clamp(min=self.eps).log10()
                for score in (hnr, jitter, shimmer, 1 - gne)
            )

        # Windowed magnitude spectrum feeds both the spectral-shape
        # features and the mel-filterbank/DCT pipeline for MFCCs.
        taper = torch.hann_window(self.window_samples, device=windows.device)
        spectrum = torch.fft.rfft(windows * taper.view(1, 1, -1)).abs()
        spectral_features = compute_spectral_features(spectrum)
        mfccs = self.compute_dct(self.compute_fbanks(spectrum))

        # Assemble the full per-frame feature vector in documented order.
        features = torch.cat(
            (
                torch.stack((f0, hnr, jitter, shimmer, gne), dim=-1),
                spectral_features,
                mfccs,
            ),
            dim=-1,
        )

        # Smooth over time with a short moving average (as OpenSMILE does).
        if self.sma_neighbors > 1:
            features = moving_average(features, dim=1, n=self.sma_neighbors)

        return features
821+
822+
823+
def moving_average(features, dim=1, n=3):
    """Computes moving average on a given dimension.

    Arguments
    ---------
    features: torch.Tensor
        The feature tensor to smooth out.
    dim: int
        The time dimension (for smoothing).
    n: int
        The number of points in the moving average

    Returns
    -------
    smoothed_features: torch.Tensor
        The features after the moving average is applied.

    Example
    -------
    >>> feats = torch.tensor([[0., 1., 0., 1., 0., 1., 0.]])
    >>> moving_average(feats)
    tensor([[0.5000, 0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.5000]])
    """
    # avg_pool1d averages over the last dimension, so swap the target
    # dimension into that position first.
    swapped = features.transpose(dim, -1)

    # count_include_pad=False keeps edge averages unbiased: each output is
    # divided by the number of real (non-padding) samples in its window.
    smoothed = torch.nn.functional.avg_pool1d(
        swapped,
        kernel_size=n,
        stride=1,
        padding=n // 2,
        count_include_pad=False,
    )

    # Swap the smoothed dimension back to its original position.
    return smoothed.transpose(dim, -1)

0 commit comments

Comments
 (0)