|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | from dataclasses import dataclass |
| 11 | +from functools import partial |
11 | 12 | from typing import Optional |
12 | 13 |
|
13 | 14 | import torch |
|
23 | 24 | Filterbank, |
24 | 25 | spectral_magnitude, |
25 | 26 | ) |
| 27 | +from speechbrain.processing.vocal_features import ( |
| 28 | + PERIODIC_NEIGHBORS, |
| 29 | + compute_autocorr_features, |
| 30 | + compute_gne, |
| 31 | + compute_periodic_features, |
| 32 | + compute_spectral_features, |
| 33 | +) |
26 | 34 | from speechbrain.utils.autocast import fwd_default_precision |
27 | 35 | from speechbrain.utils.filter_analysis import FilterProperties |
28 | 36 |
|
| 37 | +VOICE_EPSILON = 1e-3 |
| 38 | + |
29 | 39 |
|
30 | 40 | class Fbank(torch.nn.Module): |
31 | 41 | """Generate features for input to the speech pipeline. |
@@ -660,3 +670,184 @@ def get_filter_properties(self) -> FilterProperties: |
660 | 670 |
|
    def make_streaming_context(self) -> StreamingFeatureWrapperContext:
        # Build a fresh streaming context for this feature wrapper.
        # NOTE(review): the None payload presumably means "no prior streaming
        # state yet" -- confirm against StreamingFeatureWrapperContext's usage.
        return StreamingFeatureWrapperContext(None)
| 673 | + |
| 674 | + |
class VocalFeatures(torch.nn.Module):
    """Estimates the vocal characteristics of a signal in four categories of features:
    * Autocorrelation-based
    * Period-based (jitter/shimmer)
    * Spectrum-based
    * MFCCs

    Arguments
    ---------
    min_f0_Hz: int
        The minimum allowed fundamental frequency, to reduce octave errors.
        Default is 80 Hz, based on human voice standard frequency range.
    max_f0_Hz: int
        The maximum allowed fundamental frequency, to reduce octave errors.
        Default is 300 Hz, based on human voice standard frequency range.
    step_size: float
        The time between analysis windows (in seconds).
    window_size: float
        The size of the analysis window (in seconds). Must be long enough
        to contain at least 4 periods at the minimum frequency.
    sample_rate: int
        The number of samples in a second.
    log_scores: bool
        Whether to represent the jitter/shimmer/hnr on a log scale.
    eps: float
        The minimum value before log transformation, default of
        1e-3 results in a maximum value of 30 dB.
    sma_neighbors: int
        Number of frames to average -- default 3
    """

    def __init__(
        self,
        min_f0_Hz: int = 80,
        max_f0_Hz: int = 300,
        step_size: float = 0.01,
        window_size: float = 0.05,
        sample_rate: int = 16000,
        log_scores: bool = True,
        eps: float = 1e-3,
        sma_neighbors: int = 3,
    ):
        super().__init__()

        # Convert arguments to sample counts. Max lag corresponds to min f0 and vice versa.
        self.step_samples = int(step_size * sample_rate)
        self.window_samples = int(window_size * sample_rate)
        self.max_lag = int(sample_rate / min_f0_Hz)
        self.min_lag = int(sample_rate / max_f0_Hz)
        self.sample_rate = sample_rate
        self.log_scores = log_scores
        self.eps = eps
        self.sma_neighbors = sma_neighbors

        assert (
            self.max_lag * PERIODIC_NEIGHBORS <= self.window_samples
        ), f"Need at least {PERIODIC_NEIGHBORS} periods in a window"

        n_mels, n_mfcc = 23, 4
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=self.window_samples,
            n_mels=n_mels,
        )
        self.compute_dct = DCT(input_size=n_mels, n_out=n_mfcc)
        self.compute_gne = partial(
            compute_gne, frame_len=window_size, hop_len=step_size
        )

    def forward(self, audio: torch.Tensor):
        """Compute voice features.

        Arguments
        ---------
        audio: torch.Tensor
            The audio signal to be converted to voice features.

        Returns
        -------
        features: torch.Tensor
            A [batch, frame, 17] tensor with the following features per-frame.
            * autocorr_f0: A per-frame estimate of the f0 in Hz.
            * autocorr_hnr: harmonicity-to-noise ratio for each frame.
            * periodic_jitter: Average deviation in period length.
            * periodic_shimmer: Average deviation in amplitude per period.
            * gne: The glottal-to-noise-excitation ratio.
            * spectral_centroid: "center-of-mass" for spectral frames.
            * spectral_spread: avg distance from centroid for spectral frames.
            * spectral_skew: asymmetry of spectrum about the centroid.
            * spectral_kurtosis: tailedness of spectrum.
            * spectral_entropy: The peakiness of the spectrum.
            * spectral_flatness: The ratio of geometric mean to arithmetic mean.
            * spectral_crest: The ratio of spectral maximum to arithmetic mean.
            * spectral_flux: The 2-normed diff between successive spectral values.
            * mfcc_0: The first mel cepstral coefficient.
            * mfcc_1: The second mel cepstral coefficient.
            * mfcc_2: The third mel cepstral coefficient.
            * mfcc_3: The fourth mel cepstral coefficient.
        """
        assert (
            audio.dim() == 2
        ), "Expected audio to be 2-dimensional, [batch, samples]"

        # Use frame-based autocorrelation to estimate harmonicity and f0
        frames = audio.unfold(
            dimension=-1, size=self.window_samples, step=self.step_samples
        )
        harmonicity, best_lags = compute_autocorr_features(
            frames, self.min_lag, self.max_lag
        )
        f0 = self.sample_rate / best_lags

        # Autocorrelation score is the source of harmonicity here, 1-harmonicity is noise
        # See "Harmonic to Noise Ratio Measurement - Selection of Window and Length"
        # By J. Fernandez, F. Teixeira, V. Guedes, A. Junior, and J. P. Teixeira
        # Ratio is dominated by denominator, just ignore numerator here.
        hnr = 1 - harmonicity
        jitter, shimmer = compute_periodic_features(frames, best_lags)

        # Because of resampling inside the GNE computation, the GNE frame
        # count may differ from our unfold-based framing in EITHER direction.
        # Truncate everything to the common minimum so stacking succeeds;
        # when the counts already match, these slices are no-ops.
        gne = self.compute_gne(audio, self.sample_rate)
        n_frames = min(gne.size(1), frames.size(1))
        gne = gne[:, :n_frames]
        frames = frames[:, :n_frames]
        f0 = f0[:, :n_frames]
        hnr = hnr[:, :n_frames]
        jitter = jitter[:, :n_frames]
        shimmer = shimmer[:, :n_frames]

        # These features all are close to 0 most of the time, use log to differentiate
        if self.log_scores:
            hnr = -10 * hnr.clamp(min=self.eps).log10()
            jitter = -10 * jitter.clamp(min=self.eps).log10()
            shimmer = -10 * shimmer.clamp(min=self.eps).log10()
            gne = -10 * (1 - gne).clamp(min=self.eps).log10()

        # Compute spectrum for remaining features
        hann = torch.hann_window(self.window_samples, device=frames.device)
        spectrum = torch.abs(torch.fft.rfft(frames * hann.view(1, 1, -1)))
        spectral_features = compute_spectral_features(spectrum)
        mfccs = self.compute_dct(self.compute_fbanks(spectrum))

        # Combine all features into a single tensor
        features = torch.stack((f0, hnr, jitter, shimmer, gne), dim=-1)
        features = torch.cat((features, spectral_features, mfccs), dim=-1)

        # Compute moving average (as OpenSMILE does)
        if self.sma_neighbors > 1:
            features = moving_average(features, dim=1, n=self.sma_neighbors)

        return features
| 821 | + |
| 822 | + |
def moving_average(features, dim=1, n=3):
    """Smooth a feature tensor with a centered moving average along one axis.

    Arguments
    ---------
    features: torch.Tensor
        The feature tensor to smooth out.
    dim: int
        The time dimension (for smoothing).
    n: int
        The number of points in the moving average

    Returns
    -------
    smoothed_features: torch.Tensor
        The features after the moving average is applied.

    Example
    -------
    >>> feats = torch.tensor([[0., 1., 0., 1., 0., 1., 0.]])
    >>> moving_average(feats)
    tensor([[0.5000, 0.3333, 0.6667, 0.3333, 0.6667, 0.3333, 0.5000]])
    """
    # Move the smoothing axis into last position so avg_pool1d can see it.
    swapped = features.transpose(dim, -1)

    # count_include_pad=False means edge windows average only the real
    # samples, not the zero padding -- matching OpenSMILE-style smoothing.
    smoothed = torch.nn.functional.avg_pool1d(
        swapped,
        kernel_size=n,
        stride=1,
        padding=n // 2,
        count_include_pad=False,
    )

    # Restore the original axis order.
    return smoothed.transpose(dim, -1)
0 commit comments