
Commit 39ef358

SpeechLLM LibriSpeech recipe (#2885)
1 parent de1c94d

21 files changed: 4615 additions & 2216 deletions

File tree

docs/tutorials/basics/data-loading-pipeline.ipynb

Lines changed: 2666 additions & 2163 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -114,7 +114,7 @@ ignore = [
 combine-as-imports = true
 force-wrap-aliases = true
 known-first-party = ["speechbrain"]
-known-third-party = ["torch", "torchaudio", "numpy", "scipy", "hyperpyyaml", "joblib", "packaging", "sentencepiece", "tqdm", "huggingface_hub"]
+known-third-party = ["torch", "torchaudio", "numpy", "scipy", "hyperpyyaml", "joblib", "packaging", "requests", "sentencepiece", "tqdm", "huggingface_hub"]
 split-on-trailing-comma = false

 [tool.ruff.lint.per-file-ignores]
```

recipes/LibriSpeech/ASR/transformer/README.md

Lines changed: 14 additions & 20 deletions
````diff
@@ -7,7 +7,6 @@ You can download LibriSpeech at http://www.openslr.org/12
 ```shell
 python train_with_whisper.py hparams/train_hf_whisper.yaml
 python train.py hparams/transformer.yaml
-
 ```

 # How to run on test sets only
@@ -23,6 +22,20 @@ installed in your environment (see extra-requirements.txt)**

 # Results

+## SpeechLLM
+
+Two SpeechLLM modes are supported:
+- SpeechLLM with SSL features
+- SpeechLLM with E2E features
+
+In the first mode, speech features are extracted from the audio waveforms with a pre-trained SSL model and projected into the LLM embedding space by a linear layer; the whole system is trained jointly.
+
+In the second mode, the speech features are extracted offline beforehand (see the `extract_ssl_feats.py` script) and the LLM is trained on the frozen SSL representations. This mode is faster and cheaper to train, at the cost of flexibility, since the SSL model stays frozen.
+
+| Release | Model | hyperparams file | Dev Clean WER | Dev Other WER | Test Clean WER | Test Other WER | HuggingFace link | Model link | GPUs |
+|:-------:|:-----:|:----------------:|:-------------:|:-------------:|:--------------:|:--------------:|:----------------:|:----------:|:----:|
+| 29-01-26 | WavLM Large + LLama 3.2 1B + LoRA | speechllm_e2e.yaml | 2.79 | 5.03 | 2.72 | 5.34 | [HuggingFace](https://huggingface.co/speechbrain/asr-wavlm-large-llama3.2-1b-lora-librispeech) | - | 1xA100 80GB |
+
 ## Whisper Finetuning Result:

 Following table contains whisper-finetuning results for 1 epoch using Whisper model, freezing encoder and finetuning decoder.
@@ -49,25 +62,6 @@ Following table contains whisper-finetuning results for 1 epoch using Whisper mo
 | 03-09-23 | hyperbranchformer_25M.yaml | NA | 2.36 | 5.89 | Not Avail. | Not Avail. | 1xP40 24GB
 | 05-01-24 | bayesspeech.yaml | 4.28 | 2.84 | 6.27 | Not Avail. | [DropBox](https://www.dropbox.com/scl/fo/cdken4jqfj96ev1v84jxm/h?rlkey=25eu1ytgm5ac51zqj8p65zwxd&dl=0) | 1xV100 32GB |

-# **About HyperConformer**
-HyperConformer is a new architecture, which replaces the self-attention mechanism of Conformer with the linear-time token mixing architecture HyperMixer.
-It achieves competitive or better results than Conformer while requiring less memory and compute.
-
-- Paper: https://arxiv.org/abs/2305.18281
-- HyperMixer code: https://github.com/idiap/hypermixing
-
-Please cite HyperConformer if you use it for your research or business.
-
-```bibtex
-@inproceedings{mai23_interspeech,
-  author={Florian Mai and Juan Zuluaga-Gomez and Titouan Parcollet and Petr Motlicek},
-  title={{HyperConformer}: Multi-head HyperMixer for Efficient Speech Recognition},
-  year=2023,
-  booktitle={Proc. Interspeech 2023},
-  pages={2213--2217},
-  doi={10.21437/Interspeech.2023-1611}
-}
-```

 # **About SpeechBrain**
 - Website: https://speechbrain.github.io/
````
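The linear-layer projection described for the first mode above can be sketched as follows. This is a minimal illustration, not the recipe's actual code: the feature dimension (1024 for WavLM Large) and LLM embedding dimension (2048) are assumptions for the sake of the example.

```python
import torch
import torch.nn as nn

# Assumed dimensions: WavLM Large frames (1024-d) into a 2048-d LLM
# embedding space; both values are illustrative.
ssl_dim, llm_dim = 1024, 2048

# A single linear layer maps each SSL frame into the LLM embedding space.
projector = nn.Linear(ssl_dim, llm_dim)

# Fake batch: 1 utterance, 50 SSL frames.
ssl_feats = torch.randn(1, 50, ssl_dim)
speech_embeds = projector(ssl_feats)  # shape (1, 50, llm_dim)

# The projected frames are then concatenated with the text token
# embeddings before being fed to the LLM (hypothetical text embeddings).
text_embeds = torch.randn(1, 10, llm_dim)
llm_inputs = torch.cat([speech_embeds, text_embeds], dim=1)  # (1, 60, llm_dim)
```

Training jointly then simply means the projector (and optionally the SSL encoder and LoRA adapters) receive gradients from the LLM's language-modeling loss.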
extract_ssl_feats.py

Lines changed: 139 additions & 0 deletions

```python
#!/usr/bin/env python3
"""Script to extract SSL features from the audio waveforms.

The script uses the `speechbrain.integrations.hdf5.cached_item` module to cache the features.
The cached features are used in the `train_speechllm.py` script to train the SpeechLLM ASR system.

Since we do the extraction within the pipeline in the dataloader, we must place
our hparams elements directly on device, and use a default batch size of 1.

Example
-------
python extract_ssl_feats.py hparams/extract_ssl_feats.yaml \
    --data_folder path/to/LibriSpeech \
    --output_folder path/to/feats_cache \
    --ssl_hub path/to/wavlm-large \
    --feats_cache_dir path/to/feats_cache \
    ...other_hparams...

Authors
-------
 * Adel Moumen, 2025
"""

import sys
from pathlib import Path

import torch
from hyperpyyaml import load_hyperpyyaml

import speechbrain as sb
from speechbrain.integrations.hdf5.cached_item import CachedHDF5DynamicItem
from speechbrain.utils.distributed import run_on_main
from speechbrain.utils.logger import get_logger

logger = get_logger(__name__)


def dataio_prepare(hparams):
    """This function prepares the datasets to be used in the brain class.
    It also defines the data processing pipeline through user-defined functions.
    """
    data_folder = hparams["data_folder"]

    # Define the audio pipeline: read the waveform from disk
    @sb.utils.data_pipeline.takes("wav")
    @sb.utils.data_pipeline.provides("sig")
    def audio_pipeline(wav):
        sig = sb.dataio.dataio.read_audio(wav)
        return sig

    normalizer = hparams["normalize"].to(hparams["device"]).eval()
    ssl_encoder = hparams["ssl"].to(hparams["device"]).eval()

    # Compute function shared by all datasets; outputs are cached to HDF5
    @CachedHDF5DynamicItem.cache(hparams["feats_cache_dir"], compression="gzip")
    @sb.utils.data_pipeline.takes("id", "sig")
    @sb.utils.data_pipeline.provides("feats")
    def compute_feats(uid, sig):
        sig = sig.to(hparams["device"]).unsqueeze(0)
        length = torch.ones(1, device=hparams["device"])
        with torch.no_grad(), torch.autocast("cuda", dtype=hparams["dtype"]):
            feats = normalizer(sig, length)
            feats = ssl_encoder(feats, length)
        return feats.squeeze(0).cpu()

    dynamic_items = [audio_pipeline, compute_feats]
    output_keys = ["id", "sig", "feats"]

    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["train_csv"],
        replacements={"data_root": data_folder},
        dynamic_items=dynamic_items,
        output_keys=output_keys,
    )

    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
        csv_path=hparams["valid_csv"],
        replacements={"data_root": data_folder},
        dynamic_items=dynamic_items,
        output_keys=output_keys,
    )

    # Each test split gets its own dataset
    test_datasets = {}
    for csv_file in hparams["test_csv"]:
        name = Path(csv_file).stem
        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
            csv_path=csv_file,
            replacements={"data_root": data_folder},
            dynamic_items=dynamic_items,
            output_keys=output_keys,
        )

    datasets = {"train": train_data, "valid": valid_data} | test_datasets

    for stage, dataset in datasets.items():
        logger.info(f"Iterating {stage} dataset to warm the cache.")
        dataset.iterate_once(output_keys=["feats"])


if __name__ == "__main__":
    # CLI:
    hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
    with open(hparams_file, encoding="utf-8") as fin:
        hparams = load_hyperpyyaml(fin, overrides)

    # Create ddp_group with the right communication protocol
    sb.utils.distributed.ddp_init_group(run_opts)

    # Dataset prep (parsing LibriSpeech)
    from librispeech_prepare import prepare_librispeech  # noqa

    # Create experiment directory
    sb.create_experiment_directory(
        experiment_directory=hparams["output_folder"],
        hyperparams_to_save=hparams_file,
        overrides=overrides,
    )

    # Multi-GPU (DDP) safe data preparation
    run_on_main(
        prepare_librispeech,
        kwargs={
            "data_folder": hparams["data_folder"],
            "tr_splits": hparams["train_splits"],
            "dev_splits": hparams["dev_splits"],
            "te_splits": hparams["test_splits"],
            "save_folder": hparams["output_folder"],
            "merge_lst": hparams["train_splits"],
            "merge_name": "train.csv",
            "skip_prep": hparams["skip_prep"],
        },
    )
    logger.info("Preparing data...")
    dataio_prepare(hparams)
    logger.info("Done preparing data")
```
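The key idea in the script above is the compute-once cache: the decorated `compute_feats` writes each utterance's features to disk on first access, so later epochs (and `train_speechllm.py`) read them back instead of re-running the SSL encoder. The sketch below illustrates that pattern with a stdlib pickle-backed cache; it is a simplified stand-in, not SpeechBrain's actual HDF5 implementation (`disk_cache` is a hypothetical name).

```python
import pickle
import tempfile
from pathlib import Path

def disk_cache(cache_dir):
    """Simplified stand-in for an HDF5-backed feature cache: the first
    call computes and stores the result keyed by uid; later calls for
    the same uid reload the stored result instead of recomputing."""
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)

    def decorator(compute_fn):
        def wrapper(uid, *args):
            path = cache / f"{uid}.pkl"
            if path.exists():
                with open(path, "rb") as f:
                    return pickle.load(f)  # cache hit: skip computation
            result = compute_fn(uid, *args)
            with open(path, "wb") as f:
                pickle.dump(result, f)  # cache miss: compute and store
            return result
        return wrapper
    return decorator

cache_dir = tempfile.mkdtemp()

@disk_cache(cache_dir)
def compute_feats(uid, sig):
    # Stand-in for normalization + SSL encoding.
    return [x * 2.0 for x in sig]

first = compute_feats("utt1", [1.0, 2.0])   # computed and written
second = compute_feats("utt1", [9.9, 9.9])  # served from cache; args ignored
```

Note the hit path deliberately ignores the input signal, which is why the recipe iterates every dataset once ("warming" the cache) before training starts.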
hparams/extract_ssl_feats.yaml

Lines changed: 42 additions & 0 deletions

```yaml
# ############################################################################
# Task : Extraction of self-supervised (SSL) speech features from LibriSpeech
# Usage: Precompute and cache SSL representations for downstream SpeechLLM ASR
# Authors:
#  * Adel Moumen, 2025
# ############################################################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 3407
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]
experiment_name: ssl_feats_extraction
output_folder: !ref results/<experiment_name>/<seed>
save_folder: !ref <output_folder>/save
feats_cache_dir: !ref <output_folder>/feats_cache

# Data files
data_folder: !PLACEHOLDER
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean"]
test_splits: ["test-clean", "test-other"]
skip_prep: False
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv
dtype: !name:torch.bfloat16
device: cuda

####################### Training Parameters ####################################
ssl_hub: !PLACEHOLDER
ssl_folder: !ref <save_folder>/ssl_checkpoint
ssl_frozen: True

####################### Model Components ####################################
normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence

ssl: !new:speechbrain.integrations.huggingface.wav2vec2.Wav2Vec2
    source: !ref <ssl_hub>
    output_norm: True
    freeze: !ref <ssl_frozen>
    save_path: !ref <ssl_folder>
    device_map: !ref <device>
```
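The `!ref <key>` entries in the YAML above are resolved by HyperPyYAML against earlier keys before any object is built. The snippet below is a deliberately simplified illustration of just that substitution step (HyperPyYAML itself also handles `!new:`, `!name:`, `!apply:`, and command-line overrides); the `resolve` helper is invented for this sketch.

```python
import re

# A few keys mirroring the hparams file above, with HyperPyYAML-style
# <key> placeholders still unresolved.
hparams = {
    "seed": "3407",
    "experiment_name": "ssl_feats_extraction",
    "output_folder": "results/<experiment_name>/<seed>",
    "train_csv": "<output_folder>/train.csv",
}

def resolve(value, table):
    """Recursively substitute <key> placeholders with resolved values."""
    def sub(match):
        return resolve(table[match.group(1)], table)
    return re.sub(r"<([^>]+)>", sub, value)

resolved = {k: resolve(v, hparams) for k, v in hparams.items()}
# resolved["train_csv"] -> "results/ssl_feats_extraction/3407/train.csv"
```

This is why overriding a single key such as `--output_folder` on the command line transparently relocates every derived path (`save_folder`, `feats_cache_dir`, the CSVs) in one go.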