SGMSE Voicebank Speech Enhancement Recipe #2947
pplantinga merged 30 commits into speechbrain:develop
Conversation
This reverts commit 397b012.
pplantinga left a comment
Everything looks quite good; I was able to run the code without any issues. I have only a few minor comments about sections that might work slightly better (e.g., a shorter `train.py` file) if they followed the more SpeechBrain-idiomatic way of doing things, like using the `speechbrain.processing.features.STFT` class or something similar instead of handling it all in the train file.
The only remaining pieces are:
- Add results to Dropbox and record the link and scores achieved in the README
- It would be very nice if we are able to host a model on Hugging Face and run inference using SpeechBrain code. Probably the easiest method would be to add an inference class to the `sgmse_plus.py` file that enhances a waveform when called, and can be used with `speechbrain.inference.enhancement.EnhanceWaveform`
```yaml
output_folder: /export/home/1rochdi/speechbrain/results/Voicebank/enhance/SGMSE # Main directory to store experiment results
```

This could just be "results" perhaps; this path won't exist on most systems.
```yaml
#inference_dir: !ref <output_folder>/<run_name>/enhanced_inference # Directory to store waveforms at inference
tensorboard_logs: !ref <output_folder>/tensorboard_logs/ # Directory for TensorBoard logs
```
```yaml
data_dir: /data/datasets/noisy-vctk-16k # Root dir for the dataset
```

Could be `!PLACEHOLDER` so people know to change it.
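A sketch of that suggestion: HyperPyYAML's `!PLACEHOLDER` raises an error until the user overrides the value, so nobody trains against a path that only exists on the author's machine (the example path in the comment is illustrative):

```yaml
data_dir: !PLACEHOLDER # Root dir for the dataset, e.g. /path/to/noisy-vctk-16k
```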
```python
# STFT
n_fft = self.hparams.n_fft
hop_length = self.hparams.hop_length
window_type = self.hparams.window_type
self.window = self.get_window(window_type, n_fft).to(self.device)
self.stft_kwargs = {
    "n_fft": n_fft,
    "hop_length": hop_length,
    "center": True,
    "return_complex": True,
}
```

Typically in SpeechBrain we just define the STFT in the hparams.
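For reference, a hedged sketch of what that might look like with `speechbrain.processing.features.STFT` (the `sample_rate` reference and the millisecond values below are illustrative assumptions, not taken from the PR; note that this class specifies window and hop sizes in milliseconds rather than samples, so the PR's values may need converting):

```yaml
compute_stft: !new:speechbrain.processing.features.STFT
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    win_length: 25  # ms
    hop_length: 10  # ms
```

The Brain class could then call `self.hparams.compute_stft(wavs)` instead of managing windows and `torch.stft` kwargs itself.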
```python
ema = self.modules["score_model"].ema
self.checkpointer.add_recoverable(
    name="ema",
    obj=ema,
    custom_save_hook=lambda obj, path: torch.save(
        obj.state_dict(), path
    ),
    custom_load_hook=lambda obj, path, end: obj.load_state_dict(
        torch.load(path)
    ),
)
```

This is fine, but you could also add the save/load code to the model itself with `@mark_as_saver` and `@mark_as_loader`.
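A minimal, runnable sketch of that pattern. In SpeechBrain the `save`/`load` methods below would carry the `@mark_as_saver` / `@mark_as_loader` decorators (with `@register_checkpoint_hooks` on the class) from `speechbrain.utils.checkpoints`; here stdlib `pickle` stands in for `torch.save`/`torch.load` so the sketch runs without torch, and the EMA class is a toy stand-in for the one on the score model:

```python
import pickle
import tempfile
from pathlib import Path


class ExponentialMovingAverage:
    """Toy EMA holder; in the PR this object lives on the score model."""

    def __init__(self):
        self.shadow = {"weight": 1.0}

    def state_dict(self):
        return dict(self.shadow)

    def load_state_dict(self, state):
        self.shadow = dict(state)

    # With SpeechBrain, decorate with @mark_as_saver: the Checkpointer
    # then calls it directly, no custom_save_hook lambda needed.
    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self.state_dict(), f)

    # With SpeechBrain, decorate with @mark_as_loader; the loader hook
    # also receives an end-of-epoch flag.
    def load(self, path, end_of_epoch=False):
        with open(path, "rb") as f:
            self.load_state_dict(pickle.load(f))


# Round-trip demonstration: save a modified state, restore it elsewhere.
ema = ExponentialMovingAverage()
ema.shadow["weight"] = 0.5
ckpt = Path(tempfile.mkdtemp()) / "ema.ckpt"
ema.save(ckpt)
restored = ExponentialMovingAverage()
restored.load(ckpt)
print(restored.shadow["weight"])  # 0.5
```

The upside is that `add_recoverable` then only needs `name` and `obj`, keeping `train.py` shorter, which matches the review's broader theme.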
```python
cli.add_argument(
    "--resume",
    type=str,
    default="",
    help="Path to an existing run directory to resume.",
)
resume_args, remaining = cli.parse_known_args()

hparams_file, run_opts, overrides = sb.parse_arguments(remaining)

if resume_args.resume:  # Resume
    run_dir = Path(resume_args.resume).resolve()
    hparams_file = run_dir / "hyperparams.yaml"
    overrides = overrides or ""
else:  # New
    run_name = f"run_{datetime.now():%Y-%m-%d_%H-%M-%S}"
    overrides = (overrides or "") + f"\nrun_name: '{run_name}'"
```
I like the resume mechanism here, nice work! I did get tripped up once while using it (the hparams in the results dir did not have a reference I expected to be there), but I don't think there's anything to fix here necessarily. Maybe a message stating where the hparams are loaded from would be nice.
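For illustration, the suggested message could be a one-line log at the point where the hparams file is chosen. The `resolve_hparams` helper below is hypothetical (not part of the PR) and just mirrors the resume branch above:

```python
from pathlib import Path


def resolve_hparams(resume: str, default_hparams: str) -> Path:
    """Pick the hparams file and report where it came from."""
    if resume:
        # Resumed runs read the copy saved inside the run directory,
        # which is what surprised the reviewer; say so explicitly.
        hparams_file = Path(resume).resolve() / "hyperparams.yaml"
        print(f"Resuming run: loading hyperparameters from {hparams_file}")
    else:
        hparams_file = Path(default_hparams).resolve()
        print(f"New run: loading hyperparameters from {hparams_file}")
    return hparams_file


# Example call with an illustrative run directory name
resolve_hparams("results/run_2024-01-01_00-00-00", "hparams.yaml")
```

In the actual script, `logger.info` (or `sb.utils.logger`) would be the natural replacement for `print`.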
What does this PR do?
- `recipes/Voicebank/enhance/SGMSE/` for SGMSE Voicebank enhancement
- `train.py` (adapted Brain class and training loop)
- `hparams.yaml` (hyperparameters for training)
- `enhance.py` inference script to generate enhanced audio on demand
- `extra_requirements.txt` requirements file to install dependencies required for this recipe

Before submitting
PR review
Reviewer checklist