Merged
20 changes: 17 additions & 3 deletions README.md
@@ -14,6 +14,12 @@ The goal is to create a **single**, **flexible**, and **user-friendly** toolkit

| **[Tutorials](https://speechbrain.github.io/tutorial_basics.html)** | **[Website](https://speechbrain.github.io/)** | **[Documentation](https://speechbrain.readthedocs.io/en/latest/index.html)** | **[Contributing](https://speechbrain.readthedocs.io/en/latest/contributing.html)** | **[HuggingFace](https://huggingface.co/speechbrain)** |

# PyTorch 2.0 considerations

In March 2023, PyTorch released version 2.0, which brings many enhancements to the community. Most of SpeechBrain is already compatible with PyTorch 2.0, but some parts of the code are not, and we are actively working towards full compatibility. In the meantime, we recommend that users stay on PyTorch 1.13, the version used in our experiments.

If you use SpeechBrain with PyTorch 2.0 and run into problems, please let us know by commenting on this [issue](https://github.com/speechbrain/speechbrain/issues/1897).
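As a quick sanity check, the recommended version range can be verified programmatically with the `packaging` library (the same mechanism the checkpointing code in this PR uses for its version gate). The `is_supported` helper below is hypothetical, written only for illustration:

```python
from packaging import version

def is_supported(torch_version: str) -> bool:
    """Return True if the given PyTorch version is in the recommended (<2.0) range."""
    return version.parse(torch_version) < version.parse("2.0.0")

print(is_supported("1.13.1"))  # True: the recommended release line
print(is_supported("2.0.0"))   # False: not yet fully supported
```

Using `packaging.version` rather than string comparison correctly handles pre-release and patch suffixes such as `2.0.0+cu117`.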

# Key features

SpeechBrain provides various useful tools to speed up and facilitate research on speech and language technologies:
@@ -36,10 +42,11 @@ SpeechBrain supports state-of-the-art methods for end-to-end speech recognition:
- State-of-the-art performance, or performance comparable with other existing toolkits, on several ASR benchmarks.
- Easily customizable neural language models, including RNNLM and TransformerLM. We also share several pre-trained models that you can easily use (more to come!). We support Hugging Face `datasets` to facilitate training over large text corpora.
- Hybrid CTC/Attention end-to-end ASR:
- Many available encoders: CRDNN (VGG + {LSTM,GRU,LiGRU} + DNN), ResNet, SincNet, vanilla transformers, context net-based transformers or conformers. Thanks to the flexibility of SpeechBrain, any fully customized encoder could be connected to the CTC/attention decoder and trained in a few hours of work. The decoder is fully customizable: LSTM, GRU, LiGRU, transformer, or your neural network!
- Many available encoders: CRDNN (VGG + {LSTM,GRU,Li-GRU} + DNN), ResNet, SincNet, vanilla transformers, whisper, context net-based transformers or conformers. Thanks to the flexibility of SpeechBrain, any fully customized encoder could be connected to the CTC/attention decoder and trained in a few hours of work. The decoder is fully customizable: LSTM, GRU, LiGRU, transformer, or your neural network!
- Optimized and fast beam search on both CPUs and GPUs.
- Transducer end-to-end ASR with both a custom Numba loss and the torchaudio one. Any encoder or decoder can be plugged into the transducer ranging from VGG+RNN+DNN to conformers.
- Pre-trained ASR models for transcribing an audio file or extracting features for a downstream task.
- Fully customizable, with the possibility to add external beam search decoders when the ones offered natively by SpeechBrain are not sufficient, such as [PyCTCDecode](https://github.com/kensho-technologies/pyctcdecode) as in our LibriSpeech CTC wav2vec recipe.
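To illustrate the decoding side conceptually: CTC-trained models emit one prediction per frame, and decoding collapses repeated symbols and removes blanks. The minimal greedy-decoding sketch below shows that collapse rule only; it is not SpeechBrain's optimized beam search, and `ctc_greedy_decode` is a hypothetical helper:

```python
def ctc_greedy_decode(frame_ids, blank=0):
    """Collapse repeated frame-level predictions and drop blanks (the CTC rule)."""
    out = []
    prev = None
    for t in frame_ids:
        # Emit a token only when it differs from the previous frame and is not blank.
        if t != prev and t != blank:
            out.append(t)
        prev = t
    return out

# Frame-level argmax ids with blank=0: a blank between repeats keeps both tokens.
print(ctc_greedy_decode([0, 1, 1, 0, 1, 2, 2, 0]))  # [1, 1, 2]
```

A beam search replaces the per-frame argmax with a set of scored hypotheses, optionally fused with a language model, but the same collapse rule applies.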

### Feature extraction and augmentation

@@ -69,7 +76,7 @@ SpeechBrain provides different models for speaker recognition, identification, a
### Grapheme-to-Phoneme (G2P)
We have models for converting characters into a sequence of phonemes. In particular, we have Transformer- and RNN-based models operating at the sentence level (i.e., converting a full sentence into a corresponding sequence of phonemes). The models are trained with data from both Wikipedia and LibriSpeech.

### Language Identification
SpeechBrain provides different models for language identification.
In particular, our best model is based on an ECAPA-TDNN trained with the [voxlingua107 dataset](http://bark.phon.ioc.ee/voxlingua107/).
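As a toy illustration of the scoring step, embedding-based language identification can be framed as comparing an utterance embedding against per-language prototype embeddings with cosine similarity. The sketch assumes embeddings are already extracted (e.g., by an ECAPA-TDNN); the `cosine_score` helper and the 2-D prototypes are purely illustrative, and the released VoxLingua107 model uses its own trained classifier head:

```python
import numpy as np

def cosine_score(emb, prototypes):
    """Score an utterance embedding against per-language prototypes; return best language."""
    emb = emb / np.linalg.norm(emb)
    scores = {}
    for lang, proto in prototypes.items():
        # Cosine similarity = dot product of unit-normalized vectors.
        scores[lang] = float(emb @ (proto / np.linalg.norm(proto)))
    return max(scores, key=scores.get), scores

# Toy 2-D prototypes standing in for learned language embeddings.
protos = {"en": np.array([1.0, 0.0]), "fr": np.array([0.0, 1.0])}
lang, _ = cosine_score(np.array([0.9, 0.1]), protos)
print(lang)  # en
```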

@@ -85,15 +92,22 @@ Combining multiple microphones is a powerful approach to achieving robustness in
- Delay-and-sum, MVDR, and GeV beamforming.
- Speaker localization.
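The delay-and-sum idea can be sketched in a few lines: shift each channel so the target signal aligns across microphones, then average. This is an integer-sample toy version (real beamformers use fractional delays estimated from the microphone geometry), and `delay_and_sum` is a hypothetical helper, not SpeechBrain's implementation:

```python
import numpy as np

def delay_and_sum(signals, delays):
    """Align each channel by its integer sample delay, then average.

    signals: array of shape (n_channels, n_samples).
    delays: per-channel delay in samples (positive = channel arrives late).
    Note: np.roll wraps around at the edges, which is fine for this toy example.
    """
    n_ch, n = signals.shape
    out = np.zeros(n)
    for ch in range(n_ch):
        out += np.roll(signals[ch], -delays[ch])
    return out / n_ch

# Two channels carrying the same unit pulse, the second delayed by 2 samples.
x = np.zeros((2, 8))
x[0, 3] = 1.0
x[1, 5] = 1.0
y = delay_and_sum(x, [0, 2])
print(int(np.argmax(y)))  # 3: the aligned pulses add coherently at sample 3
```

After alignment the target adds coherently while uncorrelated noise averages down, which is the basic robustness gain of beamforming.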

### Emotion Recognition
- Recipes for emotion recognition using SSL and ECAPA-TDNN models.

### Interpretability
- Recipes for various interpretability techniques on the ESC50 dataset.

### Spoken Language Understanding
- Recipes for training wav2vec 2.0 models with the [MEDIA](https://catalogue.elra.info/en-us/repository/browse/ELRA-E0024/) dataset.

### Performance
The recipes released with SpeechBrain implement speech processing systems with competitive or state-of-the-art performance. Below, we report the best performance achieved on some popular benchmarks:

| Dataset | Task | System | Performance |
| ------------- |:-------------:| -----:|-----:|
| LibriSpeech | Speech Recognition | wav2vec2 | WER=1.90% (test-clean) |
| LibriSpeech | Speech Recognition | CNN + Transformer | WER=2.26% (test-clean) |
| LibriSpeech | Speech Recognition | CNN + Conformer | WER=2.2% (test-clean) |
| TIMIT | Speech Recognition | CRDNN + distillation | PER=13.1% (test) |
| TIMIT | Speech Recognition | wav2vec2 + CTC/Att. | PER=8.04% (test) |
| CommonVoice (English) | Speech Recognition | wav2vec2 + CTC | WER=15.69% (test) |
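The WER values reported above are word error rates: the word-level edit distance between hypothesis and reference, divided by the reference length. A minimal sketch of that computation (the `wer` helper is illustrative, not SpeechBrain's actual scorer):

```python
def wer(ref_words, hyp_words):
    """Word error rate: (substitutions + insertions + deletions) / reference length."""
    # Standard dynamic-programming edit distance over words.
    d = [[0] * (len(hyp_words) + 1) for _ in range(len(ref_words) + 1)]
    for i in range(len(ref_words) + 1):
        d[i][0] = i
    for j in range(len(hyp_words) + 1):
        d[0][j] = j
    for i in range(1, len(ref_words) + 1):
        for j in range(1, len(hyp_words) + 1):
            cost = 0 if ref_words[i - 1] == hyp_words[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,      # deletion
                          d[i][j - 1] + 1,      # insertion
                          d[i - 1][j - 1] + cost)  # substitution or match
    return d[-1][-1] / len(ref_words)

print(wer("the cat sat".split(), "the cat sit".split()))  # 1 substitution / 3 words
```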
4 changes: 2 additions & 2 deletions recipes/VoxLingua107/lang_id/train.py
@@ -141,9 +141,9 @@ def on_stage_end(self, stage, stage_loss, epoch=None):
def dataio_prep_shards(hparams):

    # load the meta info json file
    with wds.gopen.gopen(hparams["train_meta"], "rb") as f:
    with wds.gopen(hparams["train_meta"], "rb") as f:
        train_meta = json.load(f)
    with wds.gopen.gopen(hparams["val_meta"], "rb") as f:
    with wds.gopen(hparams["val_meta"], "rb") as f:
        val_meta = json.load(f)

    # define the mapping functions in the data pipeline
4 changes: 2 additions & 2 deletions requirements.txt
@@ -8,6 +8,6 @@ pre-commit>=2.3.0
scipy>=1.4.1, <1.9
sentencepiece>=0.1.91
SoundFile; sys_platform == 'win32'
torch>=1.9.0
torchaudio>=0.9.0
torch>=1.9.0,<2.0
torchaudio>=0.9.0,<2.0
tqdm>=4.42.0
46 changes: 32 additions & 14 deletions speechbrain/utils/checkpoints.py
@@ -58,6 +58,7 @@
import shutil
import logging
import warnings
from packaging import version
import speechbrain.utils._workarounds as __wa

logger = logging.getLogger(__name__)
@@ -158,20 +159,37 @@ def torch_parameter_transfer(obj, path, device):


# These dicts are indexed by class and hold the default checkpoints methods
DEFAULT_LOAD_HOOKS = {
torch.nn.Module: torch_recovery,
torch.optim.Optimizer: torch_recovery,
torch.optim.lr_scheduler._LRScheduler: torch_recovery,
torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery,
torch.cuda.amp.grad_scaler.GradScaler: torch_recovery,
}
DEFAULT_SAVE_HOOKS = {
torch.nn.Module: torch_save,
torch.optim.Optimizer: torch_save,
torch.optim.lr_scheduler._LRScheduler: torch_save,
torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
torch.cuda.amp.grad_scaler.GradScaler: torch_save,
}
if version.parse(torch.__version__) < version.parse("2.0.0"):
    DEFAULT_LOAD_HOOKS = {
        torch.nn.Module: torch_recovery,
        torch.optim.Optimizer: torch_recovery,
        torch.optim.lr_scheduler._LRScheduler: torch_recovery,
        torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery,
        torch.cuda.amp.grad_scaler.GradScaler: torch_recovery,
    }
    DEFAULT_SAVE_HOOKS = {
        torch.nn.Module: torch_save,
        torch.optim.Optimizer: torch_save,
        torch.optim.lr_scheduler._LRScheduler: torch_save,
        torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
        torch.cuda.amp.grad_scaler.GradScaler: torch_save,
    }
else:
    DEFAULT_LOAD_HOOKS = {
        torch.nn.Module: torch_recovery,
        torch.optim.Optimizer: torch_recovery,
        torch.optim.lr_scheduler.LRScheduler: torch_recovery,
        torch.optim.lr_scheduler.ReduceLROnPlateau: torch_recovery,
        torch.cuda.amp.grad_scaler.GradScaler: torch_recovery,
    }
    DEFAULT_SAVE_HOOKS = {
        torch.nn.Module: torch_save,
        torch.optim.Optimizer: torch_save,
        torch.optim.lr_scheduler.LRScheduler: torch_save,
        torch.optim.lr_scheduler.ReduceLROnPlateau: torch_save,
        torch.cuda.amp.grad_scaler.GradScaler: torch_save,
    }

DEFAULT_TRANSFER_HOOKS = {
    torch.nn.Module: torch_parameter_transfer,
}
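The hook dictionaries above are indexed by class, so looking up the hook for an arbitrary object naturally walks the object's method resolution order to find the most specific registered entry. A generic sketch of that lookup pattern (the `get_hook` helper and the toy classes are illustrative, not SpeechBrain's actual implementation):

```python
def get_hook(obj, hooks):
    """Return the hook registered for the most specific class in obj's MRO, or None."""
    for cls in type(obj).__mro__:
        if cls in hooks:
            return hooks[cls]
    return None

class Base: ...
class Child(Base): ...

# Only the base class is registered, but subclasses inherit its hook via the MRO walk.
hooks = {Base: "base-hook"}
print(get_hook(Child(), hooks))  # base-hook
```

This is why registering `torch.nn.Module` once is enough to cover every module subclass, while a more specific entry (e.g., for a particular scheduler class) would take precedence.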