
Commit 8c6f8ef

Authored by ZhaoZeyu1995, pchampio, Adel-Moumen, and mravanelli

[k2 integration] New PR copied from #2133 (#2345)

* Move k2-ZG to this new branch based on develop.
* Fix `utt ids <-> utt texts` when batch decoding > 1
* ngrams
* lm_dir / tests
* update gitignore/yaml
* fix rescoring_lm_scale
* precommit
* test
* update prune_level
* precommit
* readme
* kaldilm installation
* remove pyctcdecode
* update ngram
* readme
* Resolve docstring issues found by Adel.
* update res
* update results/ckpt
* small update README
* test passed k2 w2v
* solve ci
* fix ci / pre-commit
* update ci
* ngram update tests passed
* Update train_with_wav2vec_k2.py
* Update train_with_wav2vec_k2.py

Co-authored-by: pchampio <prr.champion@gmail.com>
Co-authored-by: Adel Moumen <adelmoumen.pro@gmail.com>
Co-authored-by: Mirco Ravanelli <mirco.ravanelli@gmail.com>
Co-authored-by: Adel Moumen <88119391+Adel-Moumen@users.noreply.github.com>

1 parent 52b2884 · commit 8c6f8ef

File tree: 22 files changed, +3892 −11 lines

.github/workflows/pythonapp.yml

Lines changed: 4 additions & 0 deletions

```diff
@@ -36,10 +36,14 @@ jobs:
       - name: Full dependencies
         run: |
           sudo apt-get update
+          # up to k2 compatible torch version
+          pip install torch==2.1.2 torchaudio==2.1.2
           pip install -r requirements.txt
           pip install --editable .
           pip install ctc-segmentation
+          pip install k2==1.24.4.dev20231220+cpu.torch2.1.2 -f https://k2-fsa.github.io/k2/cpu.html
           pip install protobuf
+          pip install kaldilm==1.15
       - name: Consistency tests with pytest
         run: |
           pytest tests/consistency
```
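The local version tag of the k2 wheel encodes the exact torch build it was compiled against, which is why the workflow pins `torch==2.1.2` before installing k2. A small sketch of how that pairing can be read off the pinned version string (the `k2_expected_torch` helper is purely illustrative and not part of any workflow):

```python
# Parse a k2 wheel version string such as "1.24.4.dev20231220+cpu.torch2.1.2"
# to recover the torch version it was built for (illustrative helper).
def k2_expected_torch(version: str) -> str:
    # The local version segment after '+' looks like "cpu.torch2.1.2".
    local = version.split("+", 1)[1]
    return local.split("torch", 1)[1]

print(k2_expected_torch("1.24.4.dev20231220+cpu.torch2.1.2"))  # → 2.1.2
```

Installing a k2 wheel whose tag disagrees with the installed torch version is the most common cause of import failures, so checking this pairing first saves a CI round-trip.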

.github/workflows/verify-docs-gen.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -17,9 +17,12 @@ jobs:
           python-version: '3.8'
       - name: Full dependencies
         run: |
+          # up to k2 compatible torch version
+          pip install torch==2.1.2 torchaudio==2.1.2
           pip install -r requirements.txt
           pip install --editable .
           pip install -r docs/docs-requirements.txt
+          pip install k2==1.24.4.dev20231220+cpu.torch2.1.2 -f https://k2-fsa.github.io/k2/cpu.html
       - name: Generate docs
         run: |
           cd docs
```

recipes/LibriSpeech/ASR/CTC/README.md

Lines changed: 25 additions & 0 deletions

````diff
@@ -1,6 +1,9 @@
 # LibriSpeech ASR with CTC and pre-trained wav2vec2 or whisper models.
 This folder contains the scripts to finetune a wav2vec2 or a whisper based system using LibriSpeech.
 You can download LibriSpeech at http://www.openslr.org/12.
+The loss function is the CTC loss, implemented in two different ways:
+- Using the [CTCLoss](https://pytorch.org/docs/stable/generated/torch.nn.CTCLoss.html) from PyTorch.
+- Using the [CTC implementation](https://github.com/k2-fsa/k2/blob/master/k2/python/k2/ctc_loss.py) from k2 (WFST-based). For an example of such a recipe, see the `train_with_wav2vec_k2.py` file.
 
 **Supported pre-trained wav2vec2:** [SpeechBrain](https://github.com/speechbrain/speechbrain/tree/develop/recipes/LibriSpeech/self-supervised-learning/wav2vec2) and [HuggingFace](https://github.com/speechbrain/speechbrain/tree/develop/recipes/CommonVoice/self-supervised-learning/wav2vec2)
 
@@ -25,6 +28,24 @@ To run a fine-tuning of "WavLM" with signal downsampled inputs (for faster train
 python train_with_wav2vec.py hparams/downsampled/train_hf_wavlm_signal_downsampling.yaml --downsampling_factor 2
 ```
 
+# WFST-based CTC loss
+To fine-tune a wav2vec 2.0 model with the WFST-based CTC loss, use the `train_with_wav2vec_k2.py` script. It creates a `lang` directory inside your output folder, containing the files required to build a lexicon FST. The tokenization method used here is a very basic character-based tokenization (e.g. `hello -> h e l l o`).
+
+To use this script, you will first need to install `k2`. The integration has been tested with `k2==1.24.4` and `torch==2.0.1`, although it should also work with any `torch` version that `k2` supports (compatibility list [here](https://k2-fsa.github.io/k2/installation/pre-compiled-cuda-wheels-linux/index.html)). You can install `k2` by following the instructions [here](https://k2-fsa.github.io/k2/installation/from_wheels.html#linux-cuda-example).
+
+Using a lexicon FST (L) during training helps guide the model to better predictions. When decoding, you can either use a simple HL decoding graph (where H is the CTC topology) or an HLG graph (where G is usually a 3-gram language model) to further improve the results. In addition, whole-lattice rescoring is supported, typically with a 4-gram language model. See `hparams/train_with_wav2vec_k2.yaml` for more details.
+
+If you choose to use a 3-gram or 4-gram language model, you can either supply pre-existing ARPA LMs (including ones you trained yourself) or specify one of the names listed in the YAML docstring for automatic downloading. Comprehensive instructions are provided in `train_hf_wav2vec_k2.yaml`.
+
+If you would like to train your own language model, please consult our recipe at `LibriSpeech/LM/train_ngram.py`.
+
+Example usage:
+```
+python train_with_wav2vec_k2.py hparams/train_hf_wav2vec_k2.yaml --data_folder=/path/to/LibriSpeech
+```
+
+To use the HLG graph (instead of the default HL), pass `--compose_HL_with_G=True`. To use the 4-gram LM for rescoring, pass `--decoding_method=whole-lattice-rescoring`. Note that this requires more memory, as the whole lattice is kept in memory during decoding. In this recipe, the default `lm_scale` is 0.4; this is the value that gave the best results in our HL-graph experiments after trying scales of `[0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4]`. When rescoring is used alongside the HLG graph, the 4-gram does not seem to bring any improvement; the best LM scale in that case was 0.2 (the lowest value we tried).
+
 # KenLM n-gram CTC rescoring
 To enable n-gram rescoring during the decoding, you can download the LibriSpeech official LM from [here](https://www.openslr.org/11/). Please make sure to install the extra dependencies first. Any KenLM language model may be used with this rescoring technique. The n-gram can either be a binary or an arpa file, but note that the binary format is faster to load. The following command shows how to use the official LibriSpeech 4-gram LM with SpeechBrain:
 ```bash
@@ -57,6 +78,10 @@ Note: by default, `topk` is set to 20 as it gives a good trade-off between WER a
 | 08-12-23 | train_hf_whisper.yaml (small) | CTCPrefixBeamSearch + test batch size = 1 | 960h | 4.73 | 3.19 | 12.65 |3.39 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) | 1xRTX3090 24GB | 2xTesla V100 32GB |
 | 08-12-23 | train_hf_whisper.yaml (small) | CTCBeamSearch + 4-gram + test batch size = 1 | 960h | 4.37 | 3.16 | 11.76 | 3.43 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) | 1xRTX3090 24GB | 2xTesla V100 32GB |
 | 08-12-23 | train_hf_whisper.yaml (small) | CTCPrefixBeamSearch + 4-gram + test batch size = 1 | 960h | 4.44 | 3.30 | 11.89 | 3.47 | Not Avail. | [Link](https://www.dropbox.com/sh/zmtp13huxn02fot/AADyKL5q0MwRhEG1-WbSXDWda?dl=0) | 1xRTX3090 24GB | 2xTesla V100 32GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HL graph + 1best decoding + test batch size = 1 | 960h | 1.83 | Not Avail. | 3.82 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/678rj1a44jt4zrxjwaetu/h?rlkey=x0xwz31nkl01qwr3k5ivtywvz&dl=0) | 1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HLG graph + 1best decoding + test batch size = 1 | 960h | 1.69 | Not Avail. | 3.44 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/c91vqlr8ase90x0m7u3v3/h?rlkey=duh55n0qzlfnfhy4auu0a4f8g&dl=0) | 1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HL graph + whole lattice rescoring + test batch size = 1 | 960h | 1.72 | Not Avail. | 3.51 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/mx6hd4zc0iyzqvixxre6q/h?rlkey=xxbpb949btmeiecw30be5qwhj&dl=0) | 1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
+| 23-01-24 | train_hf_wav2vec_k2.yaml | k2CTC + HLG graph + whole lattice rescoring + test batch size = 1 | 960h | 1.81 | Not Avail. | 3.57 | Not Avail. | Not Avail. | [Link](https://www.dropbox.com/scl/fo/kj2ujqj3votq7ue6ydh0l/h?rlkey=mibyoria19zasvuxs0iwx6plt&dl=0) | 1xRTX2080Ti 12GB | 1xRTX2080Ti 12GB |
 | 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + RNNLM Rescorer + test batch size = 1 + topk = 100 | 960h | 1.69 | 26mins15 | 3.55 | 32min44s | Not Avail. | [Link](https://www.dropbox.com/sh/k4ixa211yp5b1tm/AAD85sgYw2CH7NKk_qKMO9Tja?dl=0) | 1x A100 40GB | 2xTesla V100 40GB |
 | 08-12-23 | train_hf_wav2vec.yaml | CTCBeamSearch + TransformerLM Rescorer + test batch size = 1 + topk = 100 | 960h | 1.57 | 26mins56s | 3.37 | 32min46 | Not Avail. | [Link](https://www.dropbox.com/sh/ijqalvre7mm08ng/AAD_hsN-8dBneUMMkELsOOxga?dl=0) | 1x A100 40GB | 2xTesla V100 32GB |
 
````
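The README above notes that the recipe's tokenization is plain character-based and that a `lang` directory with lexicon files is built for the L FST. A minimal sketch of that word-to-characters mapping (illustrative only; the actual lexicon preparation lives in the recipe's `speechbrain.k2_integration` code):

```python
# Illustrative sketch of the character-level tokenization the README
# describes (e.g. "hello" -> "h e l l o"); not the recipe's actual
# lexicon-preparation code.
def char_tokenize(word: str) -> list[str]:
    return list(word)

def lexicon_line(word: str) -> str:
    # One lexicon entry per word, "<word> <char> <char> ...", the
    # text format a lexicon (L) FST is typically compiled from.
    return f"{word} {' '.join(char_tokenize(word))}"

print(lexicon_line("hello"))  # → hello h e l l o
```

With character tokens the lexicon is trivial to generate for any word list, which is why this recipe can build its `lang` directory on the fly instead of shipping a pretrained tokenizer.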

Lines changed: 2 additions & 0 deletions

```diff
@@ -1 +1,3 @@
+# k2 # It is better to install k2 with the procedure listed here: https://k2-fsa.github.io/k2/installation/from_wheels.html
+kaldilm==1.15
 kenlm
```
Lines changed: 255 additions & 0 deletions (new file)

```yaml
# ################################
# Model: wav2vec2 + DNN + CTC + LM (k2)
# Augmentation: SpecAugment
#
# This recipe trains a wav2vec2 model with a DNN and a WFST-based CTC loss.
# To use this recipe you need to have the following:
#  - A folder with the LibriSpeech dataset (see `data_folder`)
#  - A folder with a small and (optionally) a big LM (see `lm_dir`)
#    These can be downloaded in ARPA format from: http://www.openslr.org/resources/11/.
#  - A working installation of k2 (and kaldilm if you want to use ARPA LMs).
#
# Authors: Zeyu Zhao 2023
#          Georgios Karakasidis 2023
#          Pierre Champion 2023
# ################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1111
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/train_wav2vec2_char_k2/<seed>
output_wer_folder: !ref <output_folder>/
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# URL for the biggest Fairseq English wav2vec2 model.
wav2vec2_hub: facebook/wav2vec2-large-960h-lv60-self
wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# noise/rir dataset will automatically be downloaded
# data_folder_rirs: !ref <data_folder>
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean", "dev-other"]
test_splits: ["test-clean", "test-other"]
skip_prep: False
ckpt_interval_minutes: 25 # save checkpoint every N min
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
    - !ref <output_folder>/dev-clean.csv
    - !ref <output_folder>/dev-other.csv

# For k2 CTC training
lang_dir: !ref <output_folder>/lang
vocab_file: !ref <data_folder>/librispeech-vocab.txt
sil_prob: 0.
add_word_boundary: True
# For k2 decoding
test_search_beam: 32
# Beam size (for decoding)
test_output_beam: 8
test_min_active_state: 300
test_max_active_state: 3000
# Acoustic scale (multiplied by the log probs)
ac_scale: 1.5
compose_HL_with_G: False
# 1best or whole-lattice-rescoring
# decoding_method: whole-lattice-rescoring
decoding_method: 1best
# LM scale to be used for rescoring. Only used if rescoring
rescoring_lm_scale: 0.4
# This is where the 3-gram and (optionally) 4-gram LMs are stored.
# They can be in either ARPA or FST format. If the former, the FST
# equivalent will be created in the same directory using kaldilm.
lm_dir: !ref <output_folder>/lm
# The ARPA LM files are located under the lm_dir.
# - Use (recommended):
#     - 3-gram_sb.arpa
#     - 4-gram_sb.arpa
#   to download SpeechBrain pretrained models (trained on train-960 + librispeech-lm-norm.txt, 214k words).
# - Use:
#     - 3-gram.arpa
#     - 3-gram.pruned.1e-7.arpa
#     - 3-gram.pruned.3e-7.arpa
#     - 4-gram.arpa
#   to download the http://www.openslr.org/resources/11/ pretrained models (trained on librispeech-lm-norm.txt, 200k words).
# - Use another name for a model you trained yourself.
#   If the ARPA file does not exist in the lm_dir, you'll need to train it yourself.
#   Please see LibriSpeech/LM/README.md for instructions.
# Using one of the above names will automatically download the corresponding model.
# You can specify a different name, but you'll need to make sure the file exists in the lm_dir.
# Make sure to use enough RAM and CPUs, as the conversion to FST can be quite demanding.
G_arpa: 3-gram_sb.arpa
G_rescoring_arpa: 4-gram_sb.arpa
# caching: False
```
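The `ac_scale` hyperparameter above multiplies the network's log-probabilities before lattice generation, sharpening (for values above 1) or flattening the acoustic distribution relative to the LM scores. A toy illustration of the effect, with made-up probabilities rather than values from the recipe:

```python
import math

# ac_scale from the hparams above; applied to log-probs it raises the
# underlying distribution to that power (toy example, invented values).
ac_scale = 1.5
log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]
scaled = [ac_scale * lp for lp in log_probs]
# Renormalize so we can see the sharpened distribution.
z = sum(math.exp(s) for s in scaled)
probs = [math.exp(s) / z for s in scaled]
print([round(p, 3) for p in probs])  # → [0.829, 0.127, 0.045]
```

The most likely symbol gains mass (0.7 → ~0.83) at the expense of the others, which is why tuning `ac_scale` jointly with the LM scale matters during decoding.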
```yaml
# Training parameters
number_of_epochs: 1
lr: 0.9
lr_wav2vec: 0.0001
sorting: ascending # only ascending and descending are supported currently
precision: fp32
sample_rate: 16000

# With data_parallel batch_size is split into N jobs
# With DDP batch_size is multiplied by N jobs
# Must be 3 per GPU to fit 32GB of VRAM
batch_size: 6
test_batch_size: 1
num_workers: 10

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>

test_dataloader_opts:
    batch_size: !ref <test_batch_size>
    num_workers: !ref <num_workers>

# Model parameters
activation: !name:torch.nn.LeakyReLU
dnn_layers: 2
dnn_neurons: 1024
freeze_wav2vec: True

# Outputs
output_neurons: 30 # BPE size, index(blank/eos/bos) = 0

#
# Functions and classes
#
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

# Frequency drop: randomly drops a number of frequency bands to zero.
drop_freq_low: 0 # Min frequency band dropout probability
drop_freq_high: 1 # Max frequency band dropout probability
drop_freq_count_low: 1 # Min number of frequency bands to drop
drop_freq_count_high: 3 # Max number of frequency bands to drop
drop_freq_width: 0.05 # Width of frequency bands to drop

drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Time drop: randomly drops a number of temporal chunks.
drop_chunk_count_low: 1 # Min number of audio chunks to drop
drop_chunk_count_high: 5 # Max number of audio chunks to drop
drop_chunk_length_low: 1000 # Min length of audio chunks to drop
drop_chunk_length_high: 2000 # Max length of audio chunks to drop

drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: !ref <drop_chunk_length_low>
    drop_length_high: !ref <drop_chunk_length_high>
    drop_count_low: !ref <drop_chunk_count_low>
    drop_count_high: !ref <drop_chunk_count_high>

# Augmenter: Combines previously defined augmentations to perform data augmentation
wav_augment: !new:speechbrain.augment.augmenter.Augmenter
    parallel_augment: False
    repeat_augment: 1
    shuffle_augmentations: False
    min_augmentations: 4
    max_augmentations: 4
    augment_prob: 1.0
    augmentations: [
        !ref <speed_perturb>,
        !ref <drop_freq>,
        !ref <drop_chunk>]

enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
    input_shape: [null, null, 1024]
    activation: !ref <activation>
    dnn_blocks: !ref <dnn_layers>
    dnn_neurons: !ref <dnn_neurons>

wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.Wav2Vec2
    source: !ref <wav2vec2_hub>
    output_norm: True
    freeze: !ref <freeze_wav2vec>
    save_path: !ref <wav2vec2_folder>

#####
# Uncomment this block if you prefer to use a Fairseq pretrained model instead
# of a HuggingFace one. Here, we provide a URL that is obtained from the
# Fairseq github for the multilingual XLSR.
#
#wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
#wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
#    pretrained_path: !ref <wav2vec2_url>
#    output_norm: True
#    freeze: False
#    save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

ctc_cost: !name:speechbrain.k2_integration.losses.ctc_k2
    reduction: mean
    beam_size: 10

modules:
    wav2vec2: !ref <wav2vec2>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

model_opt_class: !name:torch.optim.Adadelta
    lr: !ref <lr>
    rho: 0.95
    eps: 1.e-8

wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>

lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.8
    patient: 0

lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_wav2vec>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        wav2vec2: !ref <wav2vec2>
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        scheduler_wav2vec: !ref <lr_annealing_wav2vec>
        counter: !ref <epoch_counter>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats

cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
    split_tokens: True
```
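The two `NewBobScheduler` entries anneal a learning rate whenever the relative improvement of the validation metric falls below `improvement_threshold`. A minimal sketch of that new-bob behavior with the values configured above (illustrative; not SpeechBrain's `NewBobScheduler` implementation, which also tracks patience and state for checkpointing):

```python
# New-bob style annealing sketch, using the model scheduler's settings
# from the hparams above (improvement_threshold=0.0025, factor=0.8).
def newbob_step(lr, prev_loss, curr_loss,
                improvement_threshold=0.0025, annealing_factor=0.8):
    # Relative improvement of the validation metric between epochs.
    improvement = (prev_loss - curr_loss) / prev_loss
    if improvement < improvement_threshold:
        lr *= annealing_factor  # anneal when progress stalls
    return lr

print(round(newbob_step(0.9, prev_loss=10.0, curr_loss=9.99), 4))  # stalled → 0.72
print(newbob_step(0.9, prev_loss=10.0, curr_loss=9.0))             # improving → 0.9
```

With `patient: 0` as in this recipe, annealing triggers on the first stalled epoch; a larger `patient` would tolerate that many stalled epochs before reducing the rate.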
