Skip to content

Commit 2f9e2fd

Browse files
fpaissanycemsubakanFrancesco Paissanasumagicpoonehmousavi
authored
Adding recipe for Listenable Maps for Audio Classifiers (#2538)
* works on cnn14 -- but have a bad checkpoint * fixed l2i as well * fixed acc in l2i * fix not listenable * updated logging for eval * a bit less verbose * printing at sample level * fix logging - was missing avg * was messing up in the forward * now running train_piq.py * minor corrections * fix l2i training with wham! * fixed l2i computation * linters * add check for wham usage in eval * add sample saving during eval * bug fixes * added predictions info to the logging * fixed id for overlap test * cutting sample before saving * fixed l2i sampling rate * fixed random seed so eval will match * running on full set * faithfulness fix * remove pdb * fix smoothgrad and IG * fix nmf for pre-training * removed nmf reconstructions * truncated gaussian fix for smoothgrad * fix nans in sensitivity * better l2i psi network * saving to a different folder. helps not overriding experiments.. * fix l2i * fix csv logging of exps * add guided backprop * added gradcam. guided backprop and guided gradcam need debugging * l2i encoder 1D * mel only - ao * eval for mel only * changed logging to simple write * hardcoded checkpoint - to run on cc * save everything in one folder * remove joblib import * fixed eval? * fix eval again.. * maybe now? * trying on cc * add eval_outdir * runs full eval * l2i with updated psi * update gitignore * l2i logging different loss values * add us8k classifier * us8k interpretations * fixed guided backprop and guided gradcam * add shap * normalizing shap attributions * adding us8k prepare in interp.. * eval on ID * fixed backward compatibility * added multiclass classification * eval xplorer v1 * eval xplorer v2 * implemented multi label interpretation * update the loss function in multilabel interpretations * evaluation explorer - minor fixes * add roar * roar test * just removing a print... * add roar script * adding the user study parsing script * savefigs * fix to roar hparam * minor * extract samples for user study * fix bug roar * fixed roar * fix another copy-paste error * MRT eval * roar with random baseline * fix np seed * computes mrt metrics * saving masks for mrt viz * remove rand baseline roar * abs * gradcam eval * fix class * add mrt to l2i * train piq us8k * param in mrt_evaluator * add viz * adding the latest * fixing path problems for multilabelstuff * changed the loss function to output 10 masks * more standard maskout term * changed encoder loading to local * added accuracy computation * removed unnecessary evaluation methods * added all ones mask and average energy computation * fixed the bug for whitenoise * pushing eval later * l2i new ood * removing useless files * cleaning up classification as well * removing useless hparams in interpret * more useless files * old linters * fix paths * fix paths * update Cnn14 * restored old piq file * wham on PIQ * Adding LMAC - needs refactor (#5) * WHAM-ing the data * AO on conv2d classifier * added interpretability metrics * fix debug steps -- updated * minor to train_piq * fix saving interpretations * add wham! for L2I * fix l2i eval * add NCC * cross correlation w/ batching * checked crosscor * finish finetuning script * switch to l1 * linters * add binarized oracle w/ BCE * fix compute loss in finetuning while saving samples * comparison script * fix 0dB mixtures * add original wav to comparison * just path to new classifier * just committing new checkpoint for L2I * add NMF image logging for debug * fix bug in viz L2I * log the number of finetuning masks * lower crosscor thr * fix acc * align L2I debugging w/ PIQ script * fixed accuracy computation for L2I * L2I with variable number of components (K=200) * debugging l2i... * update hparams * fixed oracle source * fixed wrong sources and running finetuning experiments.. * add AST as classifier * hparams ast -- still not converging * add ast augmentation * update training script after merge * with augmentations is better * just pushing hparams * classification with CE * conv2d fix for CE * playing with AST augmentation * fixed thresholding * starting to experiment with no wham noise stuff * add wham noise option in classifier training, dot prod correlation in finetuning * single mask training * added zero grad * added the entropy loss * implemented a psi function for cnn14 * Update README.md * added stft-mel transformation learning * add latest eval setup - working on gradient-based * removed unused brain -- was causing issues in weights loading.. * training l2i on this classifier * add l2i eval -- removing mosaic; not well defined in the case of L2I * removed old png file * debugging eval weight loading.. * was always using vq * fixed eval AO * fixed eval -- now everything's fine also for L2I * better numerical stability * handling quantus assertionerror * add saliency from captum * updated smoothgrad for captum * added norm to saliency * IG from captum * starting gradient-base eval on cnn14... * commit before merge * works on cnn14 -- but have a bad checkpoint * fixed l2i as well * fixed acc in l2i * fix not listenable * updated logging for eval * a bit less verbose * printing at sample level * fix logging - was missing avg * was messing up in the forward * now running train_piq.py * minor corrections * fix l2i training with wham! * fixed l2i computation * linters * add check for wham usage in eval * add sample saving during eval * bug fixes * added predictions info to the logging * fixed id for overlap test * cutting sample before saving * fixed l2i sampling rate * fixed random seed so eval will match * running on full set * faithfulness fix * remove pdb * fix smoothgrad and IG * fix nmf for pre-training * removed nmf reconstructions * truncated gaussian fix for smoothgrad * fix nans in sensitivity * better l2i psi network * saving to a different folder. helps not overriding experiments.. * fix l2i * fix csv logging of exps * add guided backprop * added gradcam. guided backprop and guided gradcam need debugging * l2i encoder 1D * mel only - ao * eval for mel only * changed logging to simple write * hardcoded checkpoint - to run on cc * save everything in one folder * remove joblib import * fixed eval? * fix eval again.. * maybe now? * trying on cc * add eval_outdir * runs full eval * l2i with updated psi * update gitignore * l2i logging different loss values * add us8k classifier * us8k interpretations * fixed guided backprop and guided gradcam * add shap * normalizing shap attributions * adding us8k prepare in interp.. * eval on ID * fixed backward compatibility * added multiclass classification * eval xplorer v1 * eval xplorer v2 * implemented multi label interpretation * update the loss function in multilabel interpretations * evaluation explorer - minor fixes * add roar * roar test * just removing a print... * add roar script * adding the user study parsing script * savefigs * fix to roar hparam * minor * extract samples for user study * fix bug roar * fixed roar * fix another copy-paste error * MRT eval * roar with random baseline * fix np seed * computes mrt metrics * saving masks for mrt viz * remove rand baseline roar * abs * gradcam eval * fix class * add mrt to l2i * train piq us8k * param in mrt_evaluator * add viz * adding the latest * fixing path problems for multilabelstuff * changed the loss function to output 10 masks * more standard maskout term * changed encoder loading to local * added accuracy computation * removed unnecessary evaluation methods * added all ones mask and average energy computation * fixed the bug for whitenoise * pushing eval later * l2i new ood * removing useless files * cleaning up classification as well * removing useless hparams in interpret * more useless files * old linters * fix paths * fix paths * update Cnn14 * restored old piq file * wham on PIQ --------- Co-authored-by: Cem Subakan <csubakan@gmail.com> Co-authored-by: Francesco Paissan <fpaissan@cedar1.cedar.computecanada.ca> * removed useless code. needs to be modified to run with self.interpret_sample * parent class and piq mods * fix fn names * simplify viz * move data prep function * L2I with parent class * removed 1 decoderator * commenting viz_ints. need std * unifying viz * change fn call * removed abstract class * disable viz_ints * rm bl comp * l2i viz * remove l2i fid * add lens * removed some metrics * extra_metric fix * removed another metric * removed another metric * starting to std viz * inp fid * fix ic * removing metrics as they will be compute elsewhere * viz piq * viz piq remove mask_ll * uniform piq viz * PIQ fits parent class * starting to unify metrics eval * fixed metrics -- missing SPS and COMP * linters * lmac into template * update lmac hparams * minor * not converging * converging now * computing metrics * computing extra metrics * extra metrics for l2i * starting SPS and COMP * Adds quantus SPS and COMP metrics to the refactoring code (#6) * starting to add quantus metrics * add sps and com * quantus metrics L2I * add quantus reqs * removed unused file * still throws strange error * ood eval * fixed paddedbatch stuff * eval L2I * remove useless files * using right wham preparation * removing model wrapper as it is not needed * fix ID samples * fix linters * model finetuning test * pretrained_PIQ -> pretrained_interpreter * update README.md * added README instructions for training with WHAM! * removing the dataset tag on experiment name * Fix Checks (#8) * Skip lazy imports when the caller is inspect.py This avoids having certain inspect functions import our lazy modules when we don't want them to. `getframeinfo` in particular appears to do it, and this gets called by PyTorch at some point. IPython might also be doing it but autocomplete still seems to work. This does not appear to break anything. Added test for hyperpyyaml to ensure we're not breaking that. * SSL_Semantic_Token _ new PR (#2509) * remove unnecassry files and move to dasb * remove extra recepie from test * update ljspeech qunatization recepie * add discrete_ssl and remove extra files * fix precommit * update kmeans and add tokeizer for postprocessing * fix precommit * Update discrete_ssl.py * fix clone warning --------- Co-authored-by: Mirco Ravanelli <mirco.ravanelli@gmail.com> * _ensure_module Raises docstring * Expose `ensure_module` so that docs get generated for it This is already an internal class anyway, and this is safe to call. * Update actions/setup-python * Use `uv` in test CI + merge some dep installs The consequence is faster dependency installation. Merging some of the dependency installs helps avoid some packages being reinstalled from one line to the next. Additionally, CPU versions are specified when relevant, to avoid downloading CUDA stuff the CI can't use anyway. * Use `uv` in doc CI + merge some dep installs Similar rationale as for the test CI * Parallelize doc generation with Sphinx This does not affect the entire doc generation process but should allow some minor multithreading even with the 2-core CI workers. * Enable `uv` caching on the test CI * Enable `uv` caching on the docs CI * CTC-only training recipes for LibriSpeech (code from Samsung AI Cambridge) (#2290) CTC-only pre-training of conformer and branchformer. --------- Co-authored-by: Shucong Zhang/Embedded AI /SRUK/Engineer/Samsung Electronics <s1.zhang@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com> Co-authored-by: Adel Moumen <88119391+Adel-Moumen@users.noreply.github.com> Co-authored-by: Parcollet Titouan <titouan.parcollet@univ-avignon.fr> * Update CommonVoice transformer recipes (code from Samsung AI Center Cambridge) (#2465) * Update CV transformer recipes to match latest results with conformer. --------- Co-authored-by: Titouan Parcollet/Embedded AI /SRUK/Engineer/Samsung Electronics <t.parcollet@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Mirco Ravanelli <mirco.ravanelli@gmail.com> Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com> * Whisper improvements: flash attention, KV caching, lang_id, translation, training... (#2450) Whisper improvements: - flash attention - kv caching - lang identifaction - translation - finetuning amelioration ... and more ... * Update README.md * precommit * update zed download link (#2514) * `RelPosEncXL` refactor and precision fixes (#2498) * Add `RelPosEncXL.make_pe`, rework precision handling * Rework RelPosEncXL output dtype selection * Fix in-place input normalization when using `sentence`/`speaker` norm (#2504) * fix LOCAL_RANK to be RANK in if_main_process (#2506) * Fix Separation and Enhancement recipes behavior when NaN encountered (#2524) * Fix Separation and Enhancement recipes behavior when NaN encountered * Formatting using precommit hooks * Lock torch version in requirements.txt (#2528) * Fix compatibility for torchaudio versions without `.io` (#2532) This avoids having the Python interpreter attempt to resolve the type annotation directly. * fix docstrings * consistency tests - classification * consistency tests - classification * consistency tests - interpret * default to no wham * fix after tests pass * fix after tests pass * tests after that * fix consistency --------- Co-authored-by: asu <sdelang@sdelang.fr> Co-authored-by: Pooneh Mousavi <moosavi.pooneh@gmail.com> Co-authored-by: Mirco Ravanelli <mirco.ravanelli@gmail.com> Co-authored-by: shucongzhang <104781888+shucongzhang@users.noreply.github.com> Co-authored-by: Shucong Zhang/Embedded AI /SRUK/Engineer/Samsung Electronics <s1.zhang@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com> Co-authored-by: Adel Moumen <88119391+Adel-Moumen@users.noreply.github.com> Co-authored-by: Parcollet Titouan <titouan.parcollet@univ-avignon.fr> Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com> Co-authored-by: Titouan Parcollet/Embedded AI /SRUK/Engineer/Samsung Electronics <t.parcollet@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Yingzhi WANG <41187612+BenoitWang@users.noreply.github.com> Co-authored-by: Peter Plantinga <plantinga.peter@protonmail.com> Co-authored-by: Séverin <123748182+SevKod@users.noreply.github.com> * added wham hparams to vit.yaml * added focalnet wham hyperparams * add eval info * add automatic wham download * additional instructions on README * wham prepare uses explicit parameters * wham docstrings * edited the instructions on different contamination types * removing the table * revert changes to gitignore * added comments on how to specify custom model * precommit hooks * fixed eval.py bug and more instructions in README.md * remove checkpoint to avoid loading from exp folder * load pretrained interpreter * save always during test * remove checkpointer call in eval.py * added few more explanations for l2i * fixed the nmf dictionary error * fix viz argument for l2i * added a comment for WHAM! noise * setting the wham to False in vit and focalnet recipes * fixed the faithfulness computation in PIQ and added AD AG AI COMPS SPS * minor documentation improvements * fixing the bug in SPS computation * formatting * Update README.md * set manifest preparation to True * fix device (not to add in yaml as it is a runnuing hparam) * added the missing docstrings for complexity sparseness faithfulness * fixed the header in eval.py * added missing l2i command in train_l2i.py * fixes to train_lmac.py * description for classifier_temp * added comments for pretrained_interpreter and ljspeech_path * updated README to have more information on how to use LJSpeech * added information for piq_vit.yaml and piq_focalnet.yaml * added more explanation for LJSpeech downloading * added missing use_melspectra_log1p attribute to piq_vit.yaml and piq_focalnet.yaml * added an assert in eval.py for the pretrained path * updates to the readme, added table, updated l2i to print quantus metrics * Update README.md * added the description of pretained_interpreter in README.md. * fixed the problem in vit * fixing l2i tests * fixed ESC50.csv * fixed the yaml tets * added links to files * fixed docstring tests * bug fix on psi model * removing the classes from PIQ.py * fixes in L2I psi classes * handling sps comp exceptions * added the dropbox links * Update README.md --------- Co-authored-by: Cem Subakan <csubakan@gmail.com> Co-authored-by: Francesco Paissan <fpaissan@cedar1.cedar.computecanada.ca> Co-authored-by: asu <sdelang@sdelang.fr> Co-authored-by: Pooneh Mousavi <moosavi.pooneh@gmail.com> Co-authored-by: Mirco Ravanelli <mirco.ravanelli@gmail.com> Co-authored-by: shucongzhang <104781888+shucongzhang@users.noreply.github.com> Co-authored-by: Shucong Zhang/Embedded AI /SRUK/Engineer/Samsung Electronics <s1.zhang@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com> Co-authored-by: Adel Moumen <88119391+Adel-Moumen@users.noreply.github.com> Co-authored-by: Parcollet Titouan <titouan.parcollet@univ-avignon.fr> Co-authored-by: Parcollet Titouan <parcollet.titouan@gmail.com> Co-authored-by: Titouan Parcollet/Embedded AI /SRUK/Engineer/Samsung Electronics <t.parcollet@sruk-ccn4.eu.corp.samsungelectronics.net> Co-authored-by: Yingzhi WANG <41187612+BenoitWang@users.noreply.github.com> Co-authored-by: Peter Plantinga <plantinga.peter@protonmail.com> Co-authored-by: Séverin <123748182+SevKod@users.noreply.github.com>
1 parent 0691acd commit 2f9e2fd

32 files changed

+2583
-1164
lines changed

recipes/ESC50/classification/README.md

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
# Sound Classification - ESC50 Dataset
22

33
This recipe trains a classifier for the ESC50 multiclass sound classification dataset.
4-
It is mainly adapted from the Speechbrain UrbanSound8k recipe.
4+
5+
The task involves classifying audio sounds into 50 different categories. These categories are divided into the following groups:
6+
7+
- Animals
8+
- Natural soundscapes and water sounds
9+
- Human, non-speech sounds
10+
- Interior/domestic sounds
11+
- Exterior/urban noises
512

613
The scripts offer the possibility to train both with log-spectra and log-mel audio features.
714

15+
## Dataset Download
16+
17+
The ESC50 dataset will be automatically downloaded when running the recipe. If you prefer to download it manually, please visit: [https://github.com/karolpiczak/ESC-50](https://github.com/karolpiczak/ESC-50)
18+
19+
820
---------------------------------------------------------------------------------------------------------
921

1022
## Installing Extra Dependencies
@@ -29,6 +41,8 @@ This script trains a [CNN14 model](https://arxiv.org/abs/1912.10211) on the ESC5
2941
python train.py hparams/cnn14.yaml --data_folder /yourpath/ESC50
3042
```
3143

44+
The dataset will be automatically download at the specified data folder.
45+
3246
---------------------------------------------------------------------------------------------------------
3347

3448
### Conv2D
@@ -61,6 +75,16 @@ python train.py hparams/vit.yaml --data_folder /yourpath/ESC50
6175

6276
---------------------------------------------------------------------------------------------------------
6377

78+
### To train with WHAM! noise
79+
80+
In order to train the classifier with WHAM! noise, you need to download the WHAM! noise dataset from [here](http://wham.whisper.ai/).
81+
Then, you can train your classifier with the following command:
82+
83+
```shell
84+
python train.py hparams/modelofchoice.yaml --data_folder /yourpath/ESC50 --add_wham_noise True --wham_folder /yourpath/wham_noise
85+
```
86+
87+
6488
## Results
6589

6690
| Hyperparams file | Accuracy (%) | Training time | HuggingFace link | Model link | GPUs |
@@ -139,4 +163,4 @@ If you use **SpeechBrain**, please cite:
139163
- Code: https://github.com/speechbrain/speechbrain/
140164
- HuggingFace: https://huggingface.co/speechbrain/
141165

142-
---------------------------------------------------------------------------------------------------------
166+
---------------------------------------------------------------------------------------------------------
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
matplotlib
2+
pandas
23
scikit-learn
34
torchvision
5+
transformers
6+
wget

recipes/ESC50/classification/hparams/cnn14.yaml

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# Authors:
66
# * Cem Subakan 2022, 2023
7-
# * Francesco Paissan 2022, 2023
7+
# * Francesco Paissan 2022, 2023, 2024
88
# (based on the SpeechBrain UrbanSound8k recipe)
99
# #################################
1010

@@ -16,11 +16,20 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
1616
data_folder: !PLACEHOLDER # e.g., /localscratch/ESC-50-master
1717
audio_data_folder: !ref <data_folder>/audio
1818

19-
experiment_name: cnn14-esc50
19+
experiment_name: !ref cnn14-esc50
2020
output_folder: !ref ./results/<experiment_name>/<seed>
2121
save_folder: !ref <output_folder>/save
2222
train_log: !ref <output_folder>/train_log.txt
2323

24+
add_wham_noise: False
25+
test_only: False
26+
27+
wham_folder: null # Set it if add_wham_noise is True.
28+
wham_audio_folder: !ref <wham_folder>/tr
29+
30+
31+
sample_rate: 16000
32+
signal_length_s: 5
2433

2534
# Tensorboard logs
2635
use_tensorboard: False
@@ -47,9 +56,7 @@ lr: 0.0002
4756
base_lr: 0.00000001
4857
max_lr: !ref <lr>
4958
step_size: 65000
50-
sample_rate: 44100
5159

52-
device: "cpu"
5360

5461
# Feature parameters
5562
n_mels: 80
@@ -58,6 +65,7 @@ right_frames: 0
5865
deltas: False
5966

6067
use_melspectra: True
68+
use_log1p_mel: True
6169

6270
# Number of classes
6371
out_n_neurons: 50
@@ -84,10 +92,9 @@ embedding_model: !new:speechbrain.lobes.models.Cnn14.Cnn14
8492
mel_bins: !ref <n_mels>
8593
emb_dim: 2048
8694

87-
classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
88-
input_size: 2048
89-
out_neurons: !ref <out_n_neurons>
90-
lin_blocks: 1
95+
classifier: !new:torch.nn.Linear
96+
in_features: 2048
97+
out_features: !ref <out_n_neurons>
9198

9299
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
93100
limit: !ref <number_of_epochs>
@@ -107,6 +114,7 @@ compute_fbank: !new:speechbrain.processing.features.Filterbank
107114
n_mels: 80
108115
n_fft: !ref <n_fft>
109116
sample_rate: !ref <sample_rate>
117+
log_mel: False
110118

111119
modules:
112120
compute_stft: !ref <compute_stft>
@@ -145,7 +153,9 @@ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
145153
counter: !ref <epoch_counter>
146154

147155
use_pretrained: True
148-
# If you do not want to use the pretrained encoder you can simply delete pretrained_encoder field.
156+
# If you do not want to use the pretrained encoder
157+
# you can simply delete pretrained_encoder field,
158+
# or set use_pretrained=False
149159
embedding_model_path: speechbrain/cnn14-esc50/embedding_model.ckpt
150160
pretrained_encoder: !new:speechbrain.utils.parameter_transfer.Pretrainer
151161
collect_in: !ref <save_folder>

recipes/ESC50/classification/hparams/conv2d.yaml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
1616
data_folder: !PLACEHOLDER # e.g., /localscratch/ESC-50-master
1717
audio_data_folder: !ref <data_folder>/audio
1818

19+
wham_folder: null # Set it if add_wham_noise is True
20+
wham_audio_folder: !ref <wham_folder>/tr
21+
1922
experiment_name: conv2dv2_classifier-16k
2023
output_folder: !ref ./results/<experiment_name>/<seed>
2124
save_folder: !ref <output_folder>/save
2225
train_log: !ref <output_folder>/train_log.txt
2326

27+
test_only: False
2428

2529
# Tensorboard logs
2630
use_tensorboard: False
@@ -48,8 +52,10 @@ base_lr: 0.000002
4852
max_lr: !ref <lr>
4953
step_size: 65000
5054
sample_rate: 16000
55+
signal_length_s: 5
56+
57+
add_wham_noise: False
5158

52-
device: "cpu"
5359

5460
# Feature parameters
5561
n_mels: 80
@@ -65,6 +71,7 @@ dataloader_options:
6571

6672
use_pretrained: True
6773
use_melspectra: False
74+
use_log1p_mel: False
6875
embedding_model: !new:speechbrain.lobes.models.PIQ.Conv2dEncoder_v2
6976
dim: 256
7077

@@ -73,6 +80,10 @@ classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
7380
out_neurons: !ref <out_n_neurons>
7481
lin_blocks: 1
7582

83+
#classifier: !new:torch.nn.Linear
84+
#in_features: 256
85+
#out_features: !ref <out_n_neurons>
86+
7687
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
7788
limit: !ref <number_of_epochs>
7889

recipes/ESC50/classification/hparams/focalnet.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ output_folder: !ref ./results/<experiment_name>/<seed>
2222
save_folder: !ref <output_folder>/save
2323
train_log: !ref <output_folder>/train_log.txt
2424

25+
add_wham_noise: False
26+
test_only: False
27+
28+
wham_folder: null # Set it if add_wham_noise is True
29+
wham_audio_folder: !ref <wham_folder>/tr
30+
31+
use_melspectra: False
32+
use_log1p_mel: False
33+
2534
# Tensorboard logs
2635
use_tensorboard: False
2736
tensorboard_logs_folder: !ref <output_folder>/tb_logs/
@@ -49,6 +58,8 @@ max_lr: !ref <lr>
4958
step_size: 65000
5059
sample_rate: 16000
5160

61+
signal_length_s: 5
62+
5263
# Number of classes
5364
out_n_neurons: 50
5465

recipes/ESC50/classification/hparams/vit.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ output_folder: !ref ./results/<experiment_name>/<seed>
2222
save_folder: !ref <output_folder>/save
2323
train_log: !ref <output_folder>/train_log.txt
2424

25+
add_wham_noise: False
26+
use_melspectra: False
27+
use_log1p_mel: False
28+
test_only: False
29+
30+
wham_folder: null # Set it if add_wham_noise is True
31+
wham_audio_folder: !ref <wham_folder>/tr
32+
2533
# Tensorboard logs
2634
use_tensorboard: False
2735
tensorboard_logs_folder: !ref <output_folder>/tb_logs/
@@ -47,7 +55,9 @@ lr: 0.0002
4755
base_lr: 0.00000001
4856
max_lr: !ref <lr>
4957
step_size: 65000
58+
5059
sample_rate: 16000
60+
signal_length_s: 5
5161

5262
# Number of classes
5363
out_n_neurons: 50

recipes/ESC50/classification/train.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,14 @@
2020

2121
import numpy as np
2222
import torch
23+
import torch.nn.functional as F
2324
import torchaudio
2425
import torchvision
2526
from confusion_matrix_fig import create_cm_fig
27+
from esc50_prepare import prepare_esc50
2628
from hyperpyyaml import load_hyperpyyaml
2729
from sklearn.metrics import confusion_matrix
30+
from wham_prepare import combine_batches, prepare_wham
2831

2932
import speechbrain as sb
3033
from speechbrain.utils.distributed import run_on_main
@@ -42,18 +45,23 @@ def compute_forward(self, batch, stage):
4245
if hasattr(self.hparams, "augmentation") and stage == sb.Stage.TRAIN:
4346
wavs, lens = self.hparams.augmentation(wavs, lens)
4447

45-
# Extract features
48+
# augment batch with WHAM!
49+
if hasattr(self.hparams, "add_wham_noise"):
50+
if self.hparams.add_wham_noise:
51+
wavs = combine_batches(wavs, iter(self.hparams.wham_dataset))
52+
4653
X_stft = self.modules.compute_stft(wavs)
47-
X_stft_power = sb.processing.features.spectral_magnitude(
54+
net_input = sb.processing.features.spectral_magnitude(
4855
X_stft, power=self.hparams.spec_mag_power
4956
)
5057
if (
5158
hasattr(self.hparams, "use_melspectra")
5259
and self.hparams.use_melspectra
5360
):
54-
net_input = self.modules.compute_fbank(X_stft_power)
55-
else:
56-
net_input = torch.log1p(X_stft_power)
61+
net_input = self.modules.compute_fbank(net_input)
62+
63+
if (not self.hparams.use_melspectra) or self.hparams.use_log1p_mel:
64+
net_input = torch.log1p(net_input)
5765

5866
# Embeddings + sound classifier
5967
if hasattr(self.modules.embedding_model, "config"):
@@ -80,11 +88,18 @@ def compute_forward(self, batch, stage):
8088
else:
8189
# SpeechBrain model
8290
embeddings = self.modules.embedding_model(net_input)
91+
if isinstance(embeddings, tuple):
92+
embeddings, _ = embeddings
93+
8394
if embeddings.ndim == 4:
8495
embeddings = embeddings.mean((-1, -2))
8596

97+
# run through classifier
8698
outputs = self.modules.classifier(embeddings)
8799

100+
if outputs.ndim == 2:
101+
outputs = outputs.unsqueeze(1)
102+
88103
return outputs, lens
89104

90105
def compute_objectives(self, predictions, batch, stage):
@@ -93,7 +108,17 @@ def compute_objectives(self, predictions, batch, stage):
93108
uttid = batch.id
94109
classid, _ = batch.class_string_encoded
95110

96-
loss = self.hparams.compute_cost(predictions, classid, lens)
111+
# Target augmentation
112+
N_augments = int(predictions.shape[0] / classid.shape[0])
113+
classid = torch.cat(N_augments * [classid], dim=0)
114+
115+
# loss = self.hparams.compute_cost(predictions.squeeze(1), classid, lens)
116+
target = F.one_hot(
117+
classid.squeeze(), num_classes=self.hparams.out_n_neurons
118+
)
119+
loss = (
120+
-(F.log_softmax(predictions.squeeze(1), 1) * target).sum(1).mean()
121+
)
97122

98123
if stage != sb.Stage.TEST:
99124
if hasattr(self.hparams.lr_annealing, "on_batch_end"):
@@ -378,8 +403,6 @@ def label_pipeline(class_string):
378403
hparams["tensorboard_logs_folder"]
379404
)
380405

381-
from esc50_prepare import prepare_esc50
382-
383406
run_on_main(
384407
prepare_esc50,
385408
kwargs={
@@ -399,6 +422,18 @@ def label_pipeline(class_string):
399422
datasets, label_encoder = dataio_prep(hparams)
400423
hparams["label_encoder"] = label_encoder
401424

425+
if "wham_folder" in hparams:
426+
hparams["wham_dataset"] = prepare_wham(
427+
hparams["wham_folder"],
428+
hparams["add_wham_noise"],
429+
hparams["sample_rate"],
430+
hparams["signal_length_s"],
431+
hparams["wham_audio_folder"],
432+
)
433+
434+
if hparams["wham_dataset"] is not None:
435+
assert hparams["signal_length_s"] == 5, "Fix wham sig length!"
436+
402437
class_labels = list(label_encoder.ind2lab.values())
403438
print("Class Labels:", class_labels)
404439

@@ -411,17 +446,21 @@ def label_pipeline(class_string):
411446
)
412447

413448
# Load pretrained encoder if it exists in the yaml file
449+
if not hasattr(ESC50_brain.modules, "embedding_model"):
450+
ESC50_brain.hparams.embedding_model.to(ESC50_brain.device)
451+
414452
if "pretrained_encoder" in hparams and hparams["use_pretrained"]:
415453
run_on_main(hparams["pretrained_encoder"].collect_files)
416454
hparams["pretrained_encoder"].load_collected()
417455

418-
ESC50_brain.fit(
419-
epoch_counter=ESC50_brain.hparams.epoch_counter,
420-
train_set=datasets["train"],
421-
valid_set=datasets["valid"],
422-
train_loader_kwargs=hparams["dataloader_options"],
423-
valid_loader_kwargs=hparams["dataloader_options"],
424-
)
456+
if not hparams["test_only"]:
457+
ESC50_brain.fit(
458+
epoch_counter=ESC50_brain.hparams.epoch_counter,
459+
train_set=datasets["train"],
460+
valid_set=datasets["valid"],
461+
train_loader_kwargs=hparams["dataloader_options"],
462+
valid_loader_kwargs=hparams["dataloader_options"],
463+
)
425464

426465
# Load the best checkpoint for evaluation
427466
test_stats = ESC50_brain.evaluate(
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../wham_prepare.py

0 commit comments

Comments
 (0)