Fix Conformer Instabilities and add Large Model #1892
New file (+290 lines):

```yaml
# ############################################################################
# Model: E2E ASR with Transformer
# Encoder: Conformer Encoder
# Decoder: Transformer Decoder + (CTC/ATT joint) beamsearch + TransformerLM
# Tokens: unigram
# losses: CTC + KLdiv (Label Smoothing loss)
# Training: Librispeech 960h
# Authors: Jianyuan Zhong, Titouan Parcollet, Samuele Cornell
# ############################################################################
# Seed needs to be set at top of yaml, before objects with parameters are made

seed: 7775
__set_seed: !apply:torch.manual_seed [!ref <seed>]
output_folder: !ref results/conformer_small/<seed>
wer_file: !ref <output_folder>/wer.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Language model (LM) pretraining
# NB: To avoid mismatch, the speech recognizer must be trained with the same
# tokenizer used for LM training. Here, we download everything from the
# speechbrain HuggingFace repository. However, a local path pointing to a
# directory containing the lm.ckpt and tokenizer.ckpt may also be specified
# instead, e.g., if you want to use your own LM / tokenizer.
pretrained_lm_tokenizer_path: speechbrain/asr-transformer-transformerlm-librispeech

# Data files
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# If the RIRS_NOISES dir exists in /localscratch/xxx_corpus/RIRS_NOISES
# then data_folder_rirs should be /localscratch/xxx_corpus,
# otherwise the dataset will automatically be downloaded
# data_folder_rirs: !ref <data_folder>
train_splits: ["train-clean-100", "train-clean-360", "train-other-500"]
dev_splits: ["dev-clean"]
test_splits: ["test-clean", "test-other"]
skip_prep: False
train_csv: !ref <output_folder>/train.csv
valid_csv: !ref <output_folder>/dev-clean.csv
test_csv:
    - !ref <output_folder>/test-clean.csv
    - !ref <output_folder>/test-other.csv

# Training parameters
# To make Transformers converge, the global batch size should be large enough.
# The global batch size is computed as batch_size * n_gpus * gradient_accumulation.
# Empirically, we found that this value should be >= 128.
# Please set your parameters accordingly.
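# Illustrative example (not this file's defaults): batch_size=16 on 2 GPUs with
# grad_accumulation_factor=4 gives a global batch size of 16 * 2 * 4 = 128.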
number_of_epochs: 110
batch_size: 16 # This works for 2x GPUs with 32GB
ctc_weight: 0.3
grad_accumulation_factor: 1
max_grad_norm: 5.0
loss_reduction: 'batchmean'
sorting: random
num_workers: 4

# stages related parameters
# stage_one_epochs: 90
lr_adam: 0.0005
# lr_sgd: 0.000025

# Feature parameters
sample_rate: 16000
n_fft: 400
n_mels: 80

# This setup works well for a V100 32GB GPU; adapt it to your needs.
# Or turn it off (but training speed will decrease)
dynamic_batching: True
max_batch_len: 900
max_batch_len_val: 100 # we reduce it as the beam is much wider (VRAM)
num_bucket: 200

dynamic_batch_sampler:
    max_batch_len: !ref <max_batch_len>
    max_batch_len_val: !ref <max_batch_len_val>
    num_buckets: !ref <num_bucket>
    shuffle_ex: True # if true, re-creates batches at each epoch by shuffling examples
    batch_ordering: random
    max_batch_ex: 128

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>

valid_dataloader_opts:
    batch_size: 1

test_dataloader_opts:
    batch_size: 1

####################### Model parameters ###########################
# Transformer
d_model: 512
nhead: 8
num_encoder_layers: 12
num_decoder_layers: 6
d_ffn: 2048
transformer_dropout: 0.1
activation: !name:torch.nn.GELU
output_neurons: 5000

# Outputs
blank_index: 0
label_smoothing: 0.0
pad_index: 0
bos_index: 1
eos_index: 2

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 1.0
valid_search_interval: 10
valid_beam_size: 10
test_beam_size: 66
lm_weight: 0.60
ctc_weight_decode: 0.40

############################## models ################################

CNN: !new:speechbrain.lobes.models.convolution.ConvolutionFrontEnd
    input_shape: (8, 10, 80)
    num_blocks: 2
    num_layers_per_block: 1
    out_channels: (64, 32)
    kernel_sizes: (3, 3)
    strides: (2, 2)
    residuals: (False, False)

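# (Illustrative note, not part of the original recipe: assuming the frontend
# flattens channels x features, the two stride-2 blocks above reduce the 80 mel
# bins to 80/2/2 = 20, and 20 features * 32 channels = 640, which matches the
# Transformer input_size below.)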
Transformer: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR # yamllint disable-line rule:line-length
    input_size: 640
    tgt_vocab: !ref <output_neurons>
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_encoder_layers>
    num_decoder_layers: !ref <num_decoder_layers>
    d_ffn: !ref <d_ffn>
    dropout: !ref <transformer_dropout>
    activation: !ref <activation>
    encoder_module: conformer
    attention_type: RelPosMHAXL
    normalize_before: True
    causal: False

# This is the TransformerLM that is used according to the HuggingFace repository.
# Visit the HuggingFace model corresponding to pretrained_lm_tokenizer_path
# for more details about the model.
# NB: It has to match the pre-trained TransformerLM!!
lm_model: !new:speechbrain.lobes.models.transformer.TransformerLM.TransformerLM # yamllint disable-line rule:line-length
    vocab: !ref <output_neurons>
    d_model: 768
    nhead: 12
    num_encoder_layers: 12
    num_decoder_layers: 0
    d_ffn: 3072
    dropout: 0.0
    activation: !name:torch.nn.GELU
    normalize_before: False

tokenizer: !new:sentencepiece.SentencePieceProcessor

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

seq_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <d_model>
    n_neurons: !ref <output_neurons>

normalize: !new:speechbrain.processing.features.InputNormalization
    norm_type: global
    update_until_epoch: 4

modules:
    CNN: !ref <CNN>
    Transformer: !ref <Transformer>
    seq_lin: !ref <seq_lin>
    ctc_lin: !ref <ctc_lin>
    normalize: !ref <normalize>

model: !new:torch.nn.ModuleList
    - [!ref <CNN>, !ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]

# define two optimizers here for two-stage training
Adam: !name:torch.optim.Adam
    lr: !ref <lr_adam>
    betas: (0.9, 0.98)
    eps: 0.000000001

#SGD: !name:torch.optim.SGD
#    lr: !ref <lr_sgd>
#    momentum: 0.99
#    nesterov: True

valid_search: !new:speechbrain.decoders.S2STransformerBeamSearch
    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <valid_beam_size>
    ctc_weight: !ref <ctc_weight_decode>
    using_eos_threshold: False
    length_normalization: False

test_search: !new:speechbrain.decoders.S2STransformerBeamSearch
    modules: [!ref <Transformer>, !ref <seq_lin>, !ref <ctc_lin>]
    bos_index: !ref <bos_index>
    eos_index: !ref <eos_index>
    blank_index: !ref <blank_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
    ctc_weight: !ref <ctc_weight_decode>
    lm_weight: !ref <lm_weight>
    lm_modules: !ref <lm_model>
    temperature: 1.15
    temperature_lm: 1.15
    using_eos_threshold: False
    length_normalization: True

log_softmax: !new:torch.nn.LogSoftmax
    dim: -1

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>
    reduction: !ref <loss_reduction>

seq_cost: !name:speechbrain.nnet.losses.kldiv_loss
    label_smoothing: !ref <label_smoothing>
    reduction: !ref <loss_reduction>

noam_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_adam>
    n_warmup_steps: 25000

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        noam_scheduler: !ref <noam_annealing>
        normalizer: !ref <normalize>
        counter: !ref <epoch_counter>

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

augmentation: !new:speechbrain.lobes.augment.SpecAugment
    time_warp: True
    time_warp_window: 5
    time_warp_mode: bicubic
    freq_mask: True
    n_freq_mask: 2
    time_mask: True
    n_time_mask: 2
    replace_with_zero: False
    freq_mask_width: 30
    time_mask_width: 40

speed_perturb: !new:speechbrain.processing.speech_augmentation.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

compute_features: !new:speechbrain.lobes.features.Fbank
    sample_rate: !ref <sample_rate>
    n_fft: !ref <n_fft>
    n_mels: !ref <n_mels>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
acc_computer: !name:speechbrain.utils.Accuracy.AccuracyStats

# The pretrainer allows a mapping between pretrained files and instances that
# are declared in the yaml. E.g., here we will download the file lm.ckpt,
# and it will be loaded into "lm", which points to the <lm_model> defined
# above.
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    collect_in: !ref <save_folder>
    loadables:
        lm: !ref <lm_model>
        tokenizer: !ref <tokenizer>
    paths:
        lm: !ref <pretrained_lm_tokenizer_path>/lm.ckpt
        tokenizer: !ref <pretrained_lm_tokenizer_path>/tokenizer.ckpt
```
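The pretrainer block above maps lm.ckpt and tokenizer.ckpt onto the lm_model and tokenizer instances declared earlier in the file. As a rough sketch of how a recipe typically consumes this (an illustration, not code from this PR; the hparams file path is hypothetical, and the standard SpeechBrain `collect_files()` / `load_collected()` pattern is assumed):

```python
# Hedged sketch: load the YAML above with HyperPyYAML, then fetch and load the
# pretrained LM and tokenizer declared in the pretrainer block.
from hyperpyyaml import load_hyperpyyaml

with open("hparams.yaml") as f:  # hypothetical path to the file shown above
    hparams = load_hyperpyyaml(f, {"data_folder": "/path/to/LibriSpeech"})

hparams["pretrainer"].collect_files()   # downloads lm.ckpt / tokenizer.ckpt
hparams["pretrainer"].load_collected()  # loads them into <lm_model> / <tokenizer>
```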
Second changed file (its path is not shown in this extract); the hunks touch argument parsing and fit_batch:

```diff
@@ -242,6 +242,12 @@ def parse_arguments(arg_list=None):
         action="store_true",
         help="This flag enables training with automatic mixed-precision.",
     )
+    parser.add_argument(
+        "--bfloat16_mix_prec",
+        default=None,
+        action="store_true",
+        help="This flag enables training with bfloat16 mixed-precision.",
+    )
     parser.add_argument(
         "--max_grad_norm",
         type=float,
```
```diff
@@ -465,6 +471,7 @@ def __init__(  # noqa: C901
             "find_unused_parameters": False,
             "jit_module_keys": None,
             "auto_mix_prec": False,
+            "bfloat16_mix_prec": False,
             "max_grad_norm": 5.0,
             "nonfinite_patience": 3,
             "noprogressbar": False,
```
```diff
@@ -915,7 +922,9 @@ def fit_batch(self, batch):
         if self.auto_mix_prec:
             with torch.cuda.amp.autocast():
                 outputs = self.compute_forward(batch, Stage.TRAIN)
-                loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
+
+            # Losses are excluded from mixed precision to avoid instabilities
+            loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
             with self.no_sync(not should_step):
                 self.scaler.scale(
                     loss / self.grad_accumulation_factor
```
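The hunk above moves the loss computation out of the fp16 autocast region, which is the heart of the stability fix. A minimal, self-contained illustration of that pattern (a toy linear model is used and a CUDA device is assumed; this is not SpeechBrain code):

```python
import torch

# Toy stand-ins for compute_forward / compute_objectives (assumes a CUDA device).
model = torch.nn.Linear(10, 5).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(4, 10, device="cuda")
targets = torch.randint(0, 5, (4,), device="cuda")

with torch.cuda.amp.autocast():
    logits = model(x)  # forward pass runs in float16 under autocast

# The loss is computed outside autocast, in full precision, to avoid instabilities.
loss = torch.nn.functional.cross_entropy(logits.float(), targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```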
```diff
@@ -928,8 +937,13 @@ def fit_batch(self, batch):
                 self.zero_grad()
                 self.optimizer_step += 1
         else:
-            outputs = self.compute_forward(batch, Stage.TRAIN)
-            loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
+            if self.bfloat16_mix_prec:
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                    outputs = self.compute_forward(batch, Stage.TRAIN)
+                    loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
+            else:
+                outputs = self.compute_forward(batch, Stage.TRAIN)
+                loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
             with self.no_sync(not should_step):
                 (loss / self.grad_accumulation_factor).backward()
             if should_step:
```

Collaborator: I think something like this is a good idea. This is not a necessary change, but a slightly more succinct way to write this would be something like:

Author: Yeah true, I'll fix this.

Author: Done.
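The reviewer's suggested snippet is not reproduced in the thread above. Purely as an illustration of the kind of consolidation being discussed (an editorial sketch, not the reviewer's code): selecting the dtype once keeps a single forward path instead of two copied branches.

```python
import torch

def forward_with_optional_bf16(model, inputs, use_bf16: bool):
    """Run a forward pass, optionally under bfloat16 autocast.

    Sketch only: `model` and `inputs` stand in for the recipe's forward logic;
    autocast is simply disabled when bfloat16 is not requested, so both cases
    share one code path.
    """
    dtype = torch.bfloat16 if use_bf16 else torch.float16  # dtype ignored when disabled
    with torch.cuda.amp.autocast(enabled=use_bf16, dtype=dtype):
        return model(inputs)
```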
There are other recipes, too. Is there another way of handling this? The mixed-precision structure/logic is independent of the actual steps taken: the code below the `with ...:` wrappers does not change, regardless of float precision. How about putting everything below the logic steps into different core functions that can be inherited & overwritten? I get that the copy/paste is easy, but with more nested logics that wrap the same code blocks ... it's prone to making mistakes.

It will affect many other files, in one way or another. Because of that, putting most of this logic into the core module would be more practical (thinking about future maintenance running into this again).

Recipes that have overloaded fit_batch won't benefit from this change. That is expected, and I am not going to refactor this. It is a deliberate design choice to override the fit_batch function... And unfortunately, it is impossible to wrap this mixed-precision stuff outside of fit_batch because it is linked to the forward / backward passes. Also, some operations are not stable with mixed precision, so it must be changed by the user in fit_batch if needed.

Yes, I'm suggesting to split fit_batch up further for more detailed handling of this logic, since this is clearly a PyTorch wrapping issue. It will re-appear.

In some cases the newly added on_fit_batch_end can address this; see the discussion at #1864 (comment).
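For reference, a hedged sketch of how a recipe can hook per-step logic without re-implementing fit_batch. It assumes the `on_fit_batch_end(batch, outputs, loss, should_step)` hook on the Brain class and a noam_annealing scheduler like the one declared in the YAML above; the toy forward/objectives are placeholders, not recipe code.

```python
import torch
import speechbrain as sb

class ToyBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        feats, labels = batch
        return self.modules.model(feats)

    def compute_objectives(self, predictions, batch, stage):
        feats, labels = batch
        return torch.nn.functional.mse_loss(predictions, labels)

    def on_fit_batch_end(self, batch, outputs, loss, should_step):
        # Runs after every fit_batch call, whichever precision branch was taken,
        # so per-step bookkeeping (e.g. a Noam LR update) can live here instead
        # of in an overridden fit_batch.
        if should_step:
            self.hparams.noam_annealing(self.optimizer)
```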