Commit 2b3f4f4

SGMSE Voicebank Speech Enhancement Recipe (#2947)

Co-authored-by: Jonas Rochdi <1rochdi@informatik.uni-hamburg.de>
Co-authored-by: Peter Plantinga <plantinga.peter@protonmail.com>

1 parent 2427785 commit 2b3f4f4

11 files changed

Lines changed: 2006 additions & 2 deletions

Lines changed: 90 additions & 0 deletions
# VoiceBank Speech Enhancement with SGMSE

This recipe implements a speech enhancement system based on the SGMSE architecture using the VoiceBank dataset (based on the paper: [https://arxiv.org/abs/2208.05830](https://arxiv.org/abs/2208.05830)).

## Results

| Experiment Date | PESQ | SI-SDR | STOI |
|-|-|-|-|
| 2025-07-24 | 2.78 | 17.8 | 95.7 |

You can find the full experiment folder (i.e., checkpoints, logs, etc.) here:
https://www.dropbox.com/scl/fo/bi8sln2de6ep8nrv38jt5/ACWQAOAIsYSMyjhcu2ZSavc?rlkey=xtqlon9xjcy43ghncnlbtruii&st=sql8s5r8&dl=0
## How to Run

### Training

To train the SGMSE speech enhancement model, execute:

```bash
python recipes/Voicebank/enhance/SGMSE/train.py recipes/Voicebank/enhance/SGMSE/hparams.yaml
```

This will:

* Prepare the VoiceBank dataset automatically (if not already prepared).
* Train the model with the hyperparameters defined in `hparams.yaml`.
* Create a `run_name` unique to each run.
* Store checkpoints, logs, and validation/testing samples in `output_dir/run_name` (specified in the `hparams.yaml` file).
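
The generated `run_name` appears to follow a timestamp pattern (the resume example in this README uses `run_YYYY-MM-DD_HH-MM-SS`). A minimal sketch of how such a name could be built; `make_run_name` is a hypothetical helper, not the recipe's actual function:

```python
from datetime import datetime


def make_run_name(prefix: str = "run") -> str:
    """Build a timestamped run name such as 'run_2025-07-24_13-05-42'."""
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    return f"{prefix}_{stamp}"


print(make_run_name())
```

Because the name encodes the start time, every invocation gets its own output directory under `output_dir`.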
### Resume Training from a Previous Run

Point `--resume` to the existing run directory (the folder that contains `hyperparams.yaml` and checkpoints):

```bash
python recipes/Voicebank/enhance/SGMSE/train.py --resume path/to/results/run_YYYY-MM-DD_HH-MM-SS
```

When `--resume` is provided:

* The script loads `hyperparams.yaml` from the given run directory and uses that saved configuration.
* Training continues from the latest checkpoint in that directory (if present), keeping the same `run_name`.
* CLI overrides still work, but a new `run_name` is not generated.
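
The resume behavior above can be sketched as follows; `resolve_config` is a hypothetical helper illustrating the decision, not code from `train.py`:

```python
from pathlib import Path


def resolve_config(resume_dir=None, default_hparams="hparams.yaml"):
    """Pick the hyperparameter file and run name for this invocation.

    When resuming, the saved hyperparams.yaml inside the run directory is
    used, so the original configuration and run_name are kept.
    """
    if resume_dir is not None:
        run_dir = Path(resume_dir)
        hparams_file = run_dir / "hyperparams.yaml"
        if not hparams_file.exists():
            raise FileNotFoundError(f"No hyperparams.yaml found in {run_dir}")
        return hparams_file, run_dir.name  # reuse the existing run_name
    # Fresh run: use the default hparams; a new run_name is generated later
    return Path(default_hparams), None
```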
### Inference (Speech Enhancement)

You can enhance single audio files or entire directories using a trained model:

* **Single-file enhancement:**

```bash
python recipes/Voicebank/enhance/SGMSE/enhancement.py --run_dir /path/to/trained_model noisy_audio.wav
```

* **Batch enhancement (whole directory):**

```bash
python recipes/Voicebank/enhance/SGMSE/enhancement.py --run_dir /path/to/trained_model /path/to/noisy_directory
```

Enhanced audio files are stored in a newly created subdirectory given by `inference_dir` in the `hparams.yaml` file, preserving the original filenames.
## Results and Outputs

During training, all results and model checkpoints are saved in:

```
<output_dir>/<run_name>/
```

During inference, enhanced audio outputs are saved in:

```
<output_dir>/<run_name>/<inference_dir>/
```

## About SpeechBrain

* Website: [https://speechbrain.github.io/](https://speechbrain.github.io/)
* Code: [https://github.com/speechbrain/speechbrain/](https://github.com/speechbrain/speechbrain/)
* HuggingFace: [https://huggingface.co/speechbrain/](https://huggingface.co/speechbrain/)

## Citing SGMSE

```bibtex
@article{richter2023speech,
    title={Speech enhancement and dereverberation with diffusion-based generative models},
    author={Richter, Julius and Welker, Simon and Lemercier, Jean-Marie and Lay, Bunlong and Gerkmann, Timo},
    journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
    volume={31},
    pages={2351--2364},
    year={2023},
    publisher={IEEE}
}
```
Lines changed: 118 additions & 0 deletions
"""
Single-file or batch speech enhancement with SGMSE.

Single file:
    python enhance.py --run_dir /path/to/run noisy.wav

Whole directory:
    python enhance.py --run_dir /path/to/run /path/to/noisy_dir
"""

import argparse
import sys
from pathlib import Path

import torch
import torchaudio
from hyperpyyaml import load_hyperpyyaml
from train import SGMSEBrain

from speechbrain.utils.checkpoints import Checkpointer


# Helpers
def is_audio_file(path):
    return path.suffix.lower() in {".wav", ".flac", ".ogg"}


def collect_audio_files(src):
    return [p for p in src.iterdir() if p.is_file() and is_audio_file(p)]


def main():
    parser = argparse.ArgumentParser(
        description="Run SGMSE enhancement (torchaudio I/O)"
    )
    parser.add_argument(
        "--run_dir",
        "-r",
        type=Path,
        required=True,
        help="Path to the trained run directory (the folder that "
        "contains hyperparams.yaml and checkpoints/).",
    )
    parser.add_argument(
        "input",
        type=Path,
        help="Path to a noisy audio file OR a directory of audio files.",
    )
    args = parser.parse_args()

    run_dir = args.run_dir.expanduser().resolve()
    if not run_dir.exists():
        sys.exit(f"--run_dir '{run_dir}' does not exist.")

    hparams_file = run_dir / "hyperparams.yaml"
    checkpoints_dir = run_dir / "checkpoints"

    with open(hparams_file, encoding="utf-8") as f:
        hparams = load_hyperpyyaml(f)

    target_sr = hparams["sample_rate"]
    inference_dir = run_dir / "enhanced_inference"
    inference_dir.mkdir(parents=True, exist_ok=True)

    modules = hparams["modules"]
    brain = SGMSEBrain(
        modules=modules,
        hparams=hparams,
        run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"},
        checkpointer=Checkpointer(
            checkpoints_dir=checkpoints_dir,
            recoverables={"score_model": modules["score_model"]},
        ),
    )
    brain.setup_inference()  # Loads the latest checkpoint, EMA weights, etc.

    # Enhancement routine
    def enhance_file(noisy_path, dst_dir):
        wav, sr = torchaudio.load(noisy_path)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        # Downmix multi-channel audio to mono
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)

        with torch.no_grad():
            wav = wav.to(brain.device)
            enhanced = brain.enhance(wav).cpu()

        # Keep the original suffix and let torchaudio infer the format from it
        out_path = dst_dir / f"{noisy_path.stem}_enhanced{noisy_path.suffix}"
        torchaudio.save(out_path.as_posix(), enhanced, target_sr)
        return out_path

    src = args.input.expanduser().resolve()

    if src.is_file():
        if not is_audio_file(src):
            sys.exit(f"{src} is not a supported audio file.")
        out_path = enhance_file(src, inference_dir)
        print(f"Enhanced file written to {out_path}")

    elif src.is_dir():
        files = collect_audio_files(src)
        if not files:
            sys.exit(f"{src} contains no enhanceable audio files.")

        batch_out_dir = inference_dir / f"{src.name}_enhanced"
        batch_out_dir.mkdir(parents=True, exist_ok=True)

        print(f"Enhancing {len(files)} file(s) > {batch_out_dir}")
        for idx, fpath in enumerate(files, 1):
            out_path = enhance_file(fpath, batch_out_dir)
            print(f"[{idx}/{len(files)}] > {out_path}")
    else:
        sys.exit(f"{src} is neither a file nor a directory.")


if __name__ == "__main__":
    main()
Lines changed: 29 additions & 0 deletions
gdown
h5py
hyperpyyaml
ipympl
librosa
ninja
numpy<2.0
pandas
pesq
pillow
protobuf
pyarrow
pyroomacoustics
pystoi
scipy
sdeint
seaborn
setuptools
git+https://github.com/sp-uhh/sgmse.git@main#egg=sgmse
tensorboard
torch
torch-ema
torch-pesq
torchaudio
torchinfo
torchmetrics
torchsde
torchvision
tqdm
Lines changed: 88 additions & 0 deletions
output_folder: results # Main directory to store experiment results
run_name: "RUN_NAME" # Will be updated with a unique name at runtime

save_dir: !ref <output_folder>/<run_name>/checkpoints # Directory to save checkpoints
enhanced_dir: !ref <output_folder>/<run_name>/enhanced_training # Directory to store waveforms enhanced at validation during training

data_dir: !PLACEHOLDER # Root dir for the dataset
train_annotation: !ref <data_dir>/train.json # JSON file listing training samples
valid_annotation: !ref <data_dir>/valid.json # JSON file listing validation samples
test_annotation: !ref <data_dir>/test.json # JSON file listing test samples

skip_prep: False # If True, skip data preparation steps
segment_frames: 256 # Number of STFT frames fed into the model; must match the input size the U-Net architecture expects
random_crop: True # Whether to crop segments randomly from longer waveforms in training
random_crop_valid: False # Whether to crop segments randomly from longer waveforms in validation
random_crop_test: False # Whether to crop segments randomly from longer waveforms in testing

normalize: noisy # Waveforms are normalized with respect to ... (noisy / clean / not)
sample_rate: 16000 # Sampling rate (in Hz) for audio data
batch_size: 8 # Batch size for the training set
number_of_epochs: 160 # Total epochs to train
num_to_keep: 2 # Number of checkpoints to keep
lr: 0.0001 # Learning rate
sorting: ascending # Sorting strategy for data loading (e.g., ascending, descending)

n_fft: 510 # FFT size for STFT
hop_length: 128 # Hop length (stride) for STFT
window_type: hann # Type of window function for STFT

transform_type: exponent # Type of spectral transform (log, exponent, none)
spec_factor: 0.15 # Factor to scale the transformed spectrogram
spec_abs_exponent: 0.5 # Exponent to apply to spectrogram magnitude if needed

train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True # Shuffle training data each epoch

valid_dataloader_opts:
    batch_size: 1 # Validation batch size

test_dataloader_opts:
    batch_size: 1 # Test batch size

sampling:
    sampler_type: pc
    predictor: reverse_diffusion
    corrector: ald
    N: 30
    corrector_steps: 1
    snr: 0.5

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs> # Sets the upper bound on training epochs

modules:
    score_model: !new:speechbrain.integrations.models.sgmse_plus.ScoreModel
        backbone: ncsnpp_v2 # Name of the backbone neural network architecture
        sde: ouve # Which SDE to use (Ornstein-Uhlenbeck VE SDE)
        theta: 1.5 # Stiffness parameter for the OU SDE
        sigma_min: 0.05 # Minimum sigma value for the OU SDE
        sigma_max: 0.5 # Maximum sigma value for the OU SDE
        lr: !ref <lr> # Learning rate for the model
        ema_decay: 0.999 # Decay factor for EMA of model parameters
        t_eps: 0.03 # Minimum time step to avoid zero in continuous diffusion
        num_eval_files: 5 # Number of files to process for evaluation
        loss_type: score_matching # Which loss approach to use (score matching, etc.)
        loss_weighting: sigma^2 # Weighting in the loss function
        network_scaling: 1/t # Scaling strategy (if any) for network outputs
        c_in: "1" # Input scaling scheme
        c_out: "1" # Output scaling scheme
        c_skip: "0" # Skip connection scaling scheme
        sigma_data: 0.1 # Data STD for EDM-based parameterizations
        l1_weight: 0.001 # Weight factor for the L1 (time-domain) loss
        pesq_weight: 0.0 # Weight factor for the PESQ-based loss (0 = disabled)
        N: !ref <sampling[N]> # Sampler steps
        corrector_steps: !ref <sampling[corrector_steps]> # Corrector updates per step
        sampler_type: !ref <sampling[sampler_type]> # SDE sampler type
        snr: !ref <sampling[snr]> # SNR for sampler
        sr: !ref <sample_rate> # Sample rate for model references

opt_class: !name:torch.optim.Adam
    lr: !ref <lr> # LR used in the Adam optimizer

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_dir> # Directory to store checkpoint files
    recoverables:
        score_model: !ref <modules[score_model]> # Model parameters to be saved
        counter: !ref <epoch_counter> # Epoch counter to be saved
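
The STFT and transform settings in this file interact: with `n_fft: 510`, each frame has 510 // 2 + 1 = 256 frequency bins, so a crop of `segment_frames: 256` yields the square 256×256 complex spectrogram the U-Net backbone expects, and the `exponent` transform compresses magnitudes while keeping phase. A PyTorch sketch of these two steps; `transform_spec` is an illustrative helper, not the recipe's code:

```python
import torch


def transform_spec(spec, spec_factor=0.15, spec_abs_exponent=0.5):
    """Magnitude-compressed spectrogram: factor * |X|**a * exp(i * angle(X))."""
    return spec_factor * torch.polar(spec.abs() ** spec_abs_exponent, spec.angle())


# n_fft=510 gives 510 // 2 + 1 = 256 frequency bins per frame.
wav = torch.randn(1, 256 * 128)  # dummy waveform
X = torch.stft(
    wav,
    n_fft=510,
    hop_length=128,
    window=torch.hann_window(510),
    return_complex=True,
)
print(X.shape[1])  # 256 frequency bins
X_t = transform_spec(X)
```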
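
The `sampling` block selects a predictor-corrector (`pc`) sampler: `N` reverse-diffusion predictor steps, each followed by `corrector_steps` annealed-Langevin (`ald`) corrector updates whose step size is set via `snr`. A toy 1-D sketch of the general predictor-corrector scheme, using a standard-normal target whose score is known in closed form; this illustrates the idea only and is not the recipe's sampler:

```python
import torch


def pc_sample(score_fn, x, sigmas, corrector_steps=1, snr=0.5):
    """Toy VE-SDE predictor-corrector sampler.

    Predictor: one Euler-Maruyama reverse-diffusion step per noise level.
    Corrector: annealed Langevin dynamics with an snr-scaled step size.
    """
    for i in range(len(sigmas) - 1):
        s, s_next = sigmas[i], sigmas[i + 1]
        dvar = torch.clamp(s**2 - s_next**2, min=0.0)
        # Predictor step at noise level s
        x = x + dvar * score_fn(x, s) + torch.sqrt(dvar) * torch.randn_like(x)
        # Corrector step(s) at the new noise level s_next
        for _ in range(corrector_steps):
            g = score_fn(x, s_next)
            z = torch.randn_like(x)
            step = 2 * (snr * z.norm() / (g.norm() + 1e-12)) ** 2
            x = x + step * g + torch.sqrt(2 * step) * z
    return x


# Toy target N(0, 1): the perturbed score is -x / (1 + sigma**2), known exactly.
score = lambda x, s: -x / (1 + s**2)
sigmas = torch.linspace(0.5, 0.05, 31)  # N = 30 steps from sigma_max to sigma_min
out = pc_sample(score, torch.randn(1000), sigmas)
```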
