Fix Conformer Instabilities and add Large Model #1892
TParcollet merged 25 commits into speechbrain:develop from …fix_conformer

Conversation
speechbrain/core.py (outdated)
    else:
        outputs = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
    if self.bfloat16_mix_prec:
I think something like this is a good idea. This is not a necessary change, but a slightly more succinct way to write this would be something like:
    with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=self.bfloat16_mix_prec):
        outputs = self.compute_forward(batch, Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
Yeah true, I'll fix this.
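For reference, a minimal sketch of how fit_batch could look with this pattern, assuming the rest of the method (backward pass, optimizer step) stays as in a plain full-precision recipe; this is an illustration, not the actual core.py code:

    import torch
    import speechbrain as sb

    class BF16Brain(sb.Brain):
        # Illustration only: the attribute name bfloat16_mix_prec follows the
        # diff above; the real fit_batch in core.py also handles fp16
        # (GradScaler), gradient accumulation, and gradient clipping.
        def fit_batch(self, batch):
            # With enabled=False the autocast context is a no-op, so the fp32
            # and bfloat16 paths share the same forward/objective lines.
            with torch.cuda.amp.autocast(
                dtype=torch.bfloat16,
                enabled=getattr(self, "bfloat16_mix_prec", False),
            ):
                outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss.detach().cpu()

This keeps a single code path for fp32 and bfloat16; the fp16 path still needs the GradScaler, as discussed further down.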
Did you still use the relative positional encoding attention?
    encoder_out = encoder_out + self.positional_encoding_decoder(
        encoder_out
    )
    # encoder_out = encoder_out + self.positional_encoding_decoder(
Probably we could just remove it.

Yes, I don't know honestly. I just saw this as the only difference between the Transformer and the Conformer, removed it, tried with a bigger model, and it was stable...
    )
    else:
        outputs = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
There are other recipes, too:
    grep -rn auto_mix_prec recipes | grep py
Is there another way of handling this? The mixed precision structure/logic is independent of the actual steps taken. How about putting the steps below this logic into separate core functions that can be inherited & overridden?
    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
    loss = self.compute_objectives(
        outputs, batch, sb.Stage.TRAIN
    )
does not change, regardless of float precision. I get that the copy/paste is easy, but with more nested logic where "with ...:" blocks wrap the same code, it's prone to mistakes.
It will affect many other files, in one way or another. Because of that, putting most of this logic into the core module would be more practical (thinking about future maintenance running into this again).
Recipes that have overloaded fit_batch won't benefit from this change. That is expected, and I am not going to refactor this. It is a deliberate design choice to override the fit_batch function... And unfortunately, it is impossible to wrap this mixed precision stuff outside of fit_batch because it is linked to the forward / backward passes. Also, some operations are not stable with mixed precision, so it must be changed by the user in fit_batch if needed.
Yes, I'm suggesting splitting fit_batch up further for more detailed handling of this logic, since this is clearly a PyTorch wrapping issue. It will re-appear.
In some cases, the newly added on_fit_batch_end can address this; see the discussion at #1864 (comment).
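To make the suggestion above concrete: a hypothetical split of fit_batch in which the forward/objective steps live in their own overridable method while the precision handling (including the fp16 GradScaler, which must wrap backward() and the optimizer step) stays in one place. The helper name _train_forward_and_loss is invented for this sketch and is not part of the PR:

    import torch
    import speechbrain as sb

    class SplitStepBrain(sb.Brain):
        def _train_forward_and_loss(self, batch):
            # Hypothetical overridable step: recipes that only change the
            # forward or the objectives override this and inherit the
            # precision handling below.
            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
            loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
            return loss

        def fit_batch(self, batch):
            if self.auto_mix_prec:
                # fp16: autocast is coupled to the GradScaler around the
                # backward pass and optimizer step, which is why the wrapper
                # cannot simply live outside fit_batch.
                with torch.cuda.amp.autocast():
                    loss = self._train_forward_and_loss(batch)
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # bfloat16 keeps the fp32 exponent range, so no GradScaler is
                # needed; enabled=False turns the context into a no-op.
                with torch.cuda.amp.autocast(
                    dtype=torch.bfloat16,
                    enabled=getattr(self, "bfloat16_mix_prec", False),
                ):
                    loss = self._train_forward_and_loss(batch)
                loss.backward()
                self.optimizer.step()
            self.optimizer.zero_grad()
            return loss.detach().cpu()

Recipes that already override fit_batch would keep working unchanged; they simply would not pick up the shared helper.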
    #     encoder_out
    # )
    pos_embs_encoder = None  # self.positional_encoding(src)
    pos_embs_target = None
positional_encoding_decoder is only set for attention_type == "RelPosMHAXL":
    if attention_type == "RelPosMHAXL":
        self.positional_encoding = RelPosEncXL(d_model)
        self.positional_encoding_decoder = PositionalEncoding(
            d_model, max_length
        )
So - what was this fixme about in the first place?
btw - it's also used in G2P:
grep -r positional_encoding_decoder speechbrain
What is this line about?
    pos_embs_encoder = None  # self.positional_encoding(src)
    pos_embs_encoder = self.positional_encoding(src)  # if we drop the other one, as you suggested
Maybe the comment there can be dropped (it would relate to the forward in class RelPosMHAXL(nn.Module)).
    ...
    self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
    ...
    p_k = self.linear_pos(pos_embs).view(
        1, -1, self.num_heads, self.head_dim
    )
At which level is the positional encoding meant to work? The paper https://arxiv.org/pdf/1901.02860.pdf cross-referenced in the SB readthedocs refers in "3.3 Relative Positional Encodings" to the decomposition of the attention score with relative encodings R_{i-j}, for which they add the trainable parameters u and v (see the equation below).
So, the crucial thing during training is to have the relative encoding in the target?
Btw... got curious about the if/else construction:
    speechbrain/speechbrain/lobes/models/transformer/TransformerASR.py Lines 186 to 199 in 14b7122
Compare:
    speechbrain/speechbrain/lobes/models/transformer/Transformer.py Lines 124 to 136 in 14b7122
and during encode():
    speechbrain/speechbrain/lobes/models/transformer/TransformerASR.py Lines 304 to 310 in 14b7122
Are there tutorials for how paper write-ups relate to code?
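For reference, the decomposition that Section 3.3 of that paper introduces (reproduced here from the Transformer-XL paper, not from the SpeechBrain code):

$$
A_{i,j}^{\mathrm{rel}}
= \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)}
+ \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)}
+ \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)}
+ \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}
$$

Here R_{i-j} is the sinusoidal relative encoding, u and v are trainable vectors, and the W_{k,R} R_{i-j} projection is presumably what linear_pos / p_k computes in the RelPosMHAXL excerpt above.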
@anautsch I cleaned this if/else. The bug is just that we should not do that, I guess :p
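To summarise the change itself (paraphrased from the diffs above; the surrounding decode() code is omitted):

    # Pre-fix: on the RelPosMHAXL / Conformer path, a sinusoidal positional
    # encoding was added to the encoder output before it was handed to the
    # decoder.
    encoder_out = encoder_out + self.positional_encoding_decoder(encoder_out)

    # Post-fix: encoder_out is passed through unchanged, as in the standard
    # Transformer, and no extra positional embeddings are injected here.
    pos_embs_encoder = None  # self.positional_encoding(src)
    pos_embs_target = None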
We've been struggling with training larger Conformers due to instability. This PR attempts to fix this. In practice, for some reason (please @popcornell check), we were adding a sinusoidal positional embedding to the encoder output before giving it to the decoder. We don't do this for the standard Transformer, so I removed it and trained a large model.
I also added support for bfloat16 training in core.py. This is a bit redundant with --auto_mix_prec though. @anautsch what do you think? The problem is that the best thing would be to rename --auto_mix_prec to --fp16_mix_prec, but this would affect soooooooooooo many files ... So maybe keeping --auto_mix_prec and --bfloat16_mix_prec will do :-)? @pplantinga what do you think also?
PRE-FIX TRAIN LOG:
epoch: 1, lr: 3.59e-05, steps: 1797, optimizer: Adam - train loss: 3.06e+02 - valid loss: 2.09e+02, valid ACC: 1.83e-01
epoch: 2, lr: 7.19e-05, steps: 3594, optimizer: Adam - train loss: 2.54e+02 - valid loss: 1.69e+02, valid ACC: 2.58e-01
epoch: 3, lr: 1.08e-04, steps: 5391, optimizer: Adam - train loss: 1.66e+02 - valid loss: 66.64, valid ACC: 6.99e-01
epoch: 4, lr: 1.44e-04, steps: 7188, optimizer: Adam - train loss: 80.97 - valid loss: 33.55, valid ACC: 8.58e-01
epoch: 5, lr: 1.80e-04, steps: 8985, optimizer: Adam - train loss: 52.80 - valid loss: 23.88, valid ACC: 9.00e-01
epoch: 6, lr: 2.16e-04, steps: 10782, optimizer: Adam - train loss: 41.22 - valid loss: 19.67, valid ACC: 9.23e-01
epoch: 7, lr: 2.52e-04, steps: 12579, optimizer: Adam - train loss: 34.78 - valid loss: 16.37, valid ACC: 9.33e-01
epoch: 8, lr: 2.88e-04, steps: 14376, optimizer: Adam - train loss: 30.60 - valid loss: 15.25, valid ACC: 9.39e-01
epoch: 9, lr: 3.23e-04, steps: 16173, optimizer: Adam - train loss: 27.49 - valid loss: 13.32, valid ACC: 9.45e-01
epoch: 10, lr: 3.59e-04, steps: 17970, optimizer: Adam - train loss: 25.36 - valid loss: 12.12, valid ACC: 9.48e-01, valid WER: 6.76
epoch: 11, lr: 3.95e-04, steps: 19767, optimizer: Adam - train loss: 23.71 - valid loss: 11.84, valid ACC: 9.50e-01
epoch: 12, lr: 4.31e-04, steps: 21564, optimizer: Adam - train loss: 22.71 - valid loss: 10.87, valid ACC: 9.52e-01
epoch: 13, lr: 4.67e-04, steps: 23361, optimizer: Adam - train loss: 21.46 - valid loss: 10.01, valid ACC: 9.54e-01
epoch: 14, lr: 4.98e-04, steps: 25158, optimizer: Adam - train loss: 20.59 - valid loss: 10.16, valid ACC: 9.55e-01
epoch: 15, lr: 4.82e-04, steps: 26955, optimizer: Adam - train loss: 19.34 - valid loss: 9.60, valid ACC: 9.57e-01
epoch: 16, lr: 4.66e-04, steps: 28752, optimizer: Adam - train loss: 17.85 - valid loss: 8.99, valid ACC: 9.60e-01
epoch: 17, lr: 4.52e-04, steps: 30549, optimizer: Adam - train loss: 16.33 - valid loss: 8.85, valid ACC: 9.62e-01
epoch: 18, lr: 4.40e-04, steps: 32346, optimizer: Adam - train loss: 15.22 - valid loss: 8.10, valid ACC: 9.64e-01
epoch: 19, lr: 4.28e-04, steps: 34143, optimizer: Adam - train loss: 14.33 - valid loss: 8.08, valid ACC: 9.65e-01
epoch: 20, lr: 4.17e-04, steps: 35940, optimizer: Adam - train loss: 13.48 - valid loss: 7.77, valid ACC: 9.66e-01, valid WER: 6.04
epoch: 21, lr: 4.07e-04, steps: 37737, optimizer: Adam - train loss: 12.74 - valid loss: 8.06, valid ACC: 9.66e-01
epoch: 22, lr: 3.98e-04, steps: 39534, optimizer: Adam - train loss: 12.14 - valid loss: 7.56, valid ACC: 9.68e-01
epoch: 23, lr: 3.89e-04, steps: 41331, optimizer: Adam - train loss: 11.59 - valid loss: 7.79, valid ACC: 9.68e-01
epoch: 24, lr: 3.81e-04, steps: 43128, optimizer: Adam - train loss: 11.11 - valid loss: 7.39, valid ACC: 9.68e-01
epoch: 25, lr: 3.73e-04, steps: 44925, optimizer: Adam - train loss: 10.68 - valid loss: 7.27, valid ACC: 9.69e-01
epoch: 26, lr: 3.66e-04, steps: 46722, optimizer: Adam - train loss: 10.28 - valid loss: 7.51, valid ACC: 9.68e-01
epoch: 27, lr: 3.59e-04, steps: 48519, optimizer: Adam - train loss: 9.86 - valid loss: 6.90, valid ACC: 9.70e-01
epoch: 28, lr: 3.52e-04, steps: 50316, optimizer: Adam - train loss: 9.56 - valid loss: 7.27, valid ACC: 9.69e-01
epoch: 29, lr: 3.46e-04, steps: 52113, optimizer: Adam - train loss: 9.18 - valid loss: 7.06, valid ACC: 9.70e-01
epoch: 30, lr: 3.40e-04, steps: 53910, optimizer: Adam - train loss: 8.94 - valid loss: 7.38, valid ACC: 9.70e-01, valid WER: 8.63
epoch: 31, lr: 3.35e-04, steps: 55707, optimizer: Adam - train loss: 8.63 - valid loss: 6.77, valid ACC: 9.72e-01
epoch: 32, lr: 3.30e-04, steps: 57504, optimizer: Adam - train loss: 8.44 - valid loss: 7.10, valid ACC: 9.70e-01
epoch: 33, lr: 3.25e-04, steps: 59301, optimizer: Adam - train loss: 8.17 - valid loss: 6.92, valid ACC: 9.71e-01
epoch: 34, lr: 3.20e-04, steps: 61098, optimizer: Adam - train loss: 7.98 - valid loss: 6.98, valid ACC: 9.71e-01
epoch: 35, lr: 3.15e-04, steps: 62895, optimizer: Adam - train loss: 7.83 - valid loss: 6.98, valid ACC: 9.71e-01
epoch: 36, lr: 3.11e-04, steps: 64692, optimizer: Adam - train loss: 7.61 - valid loss: 7.13, valid ACC: 9.71e-01
epoch: 37, lr: 3.07e-04, steps: 66489, optimizer: Adam - train loss: 7.37 - valid loss: 6.85, valid ACC: 9.72e-01
epoch: 38, lr: 3.03e-04, steps: 68286, optimizer: Adam - train loss: 7.25 - valid loss: 6.79, valid ACC: 9.72e-01
epoch: 39, lr: 2.99e-04, steps: 70083, optimizer: Adam - train loss: 7.12 - valid loss: 6.96, valid ACC: 9.71e-01
epoch: 40, lr: 2.95e-04, steps: 71880, optimizer: Adam - train loss: 7.00 - valid loss: 7.03, valid ACC: 9.72e-01, valid WER: 6.01
POST-FIX TRAIN LOG:
epoch: 1, lr: 3.59e-05, steps: 1797, optimizer: Adam - train loss: 3.07e+02 - valid loss: 2.09e+02, valid ACC: 1.79e-01
epoch: 2, lr: 7.19e-05, steps: 3594, optimizer: Adam - train loss: 2.57e+02 - valid loss: 1.86e+02, valid ACC: 2.12e-01
epoch: 3, lr: 1.08e-04, steps: 5391, optimizer: Adam - train loss: 2.01e+02 - valid loss: 96.83, valid ACC: 5.54e-01
epoch: 4, lr: 1.44e-04, steps: 7188, optimizer: Adam - train loss: 97.51 - valid loss: 37.82, valid ACC: 8.36e-01
epoch: 5, lr: 1.80e-04, steps: 8985, optimizer: Adam - train loss: 58.00 - valid loss: 24.42, valid ACC: 8.93e-01
epoch: 6, lr: 2.16e-04, steps: 10782, optimizer: Adam - train loss: 44.37 - valid loss: 18.93, valid ACC: 9.20e-01
epoch: 7, lr: 2.52e-04, steps: 12579, optimizer: Adam - train loss: 36.70 - valid loss: 16.06, valid ACC: 9.33e-01
epoch: 8, lr: 2.88e-04, steps: 14376, optimizer: Adam - train loss: 31.20 - valid loss: 13.82, valid ACC: 9.39e-01
epoch: 9, lr: 3.23e-04, steps: 16173, optimizer: Adam - train loss: 27.64 - valid loss: 12.43, valid ACC: 9.45e-01
epoch: 10, lr: 3.59e-04, steps: 17970, optimizer: Adam - train loss: 25.32 - valid loss: 11.77, valid ACC: 9.49e-01, valid WER: 6.33
epoch: 11, lr: 3.95e-04, steps: 19767, optimizer: Adam - train loss: 23.50 - valid loss: 10.55, valid ACC: 9.53e-01
epoch: 12, lr: 4.31e-04, steps: 21564, optimizer: Adam - train loss: 22.25 - valid loss: 10.52, valid ACC: 9.53e-01
epoch: 13, lr: 4.67e-04, steps: 23361, optimizer: Adam - train loss: 20.92 - valid loss: 9.84, valid ACC: 9.55e-01
epoch: 14, lr: 4.98e-04, steps: 25158, optimizer: Adam - train loss: 20.06 - valid loss: 9.40, valid ACC: 9.57e-01
epoch: 15, lr: 4.82e-04, steps: 26955, optimizer: Adam - train loss: 18.82 - valid loss: 9.19, valid ACC: 9.60e-01
epoch: 16, lr: 4.66e-04, steps: 28752, optimizer: Adam - train loss: 17.32 - valid loss: 9.02, valid ACC: 9.61e-01
epoch: 17, lr: 4.52e-04, steps: 30549, optimizer: Adam - train loss: 15.80 - valid loss: 7.95, valid ACC: 9.64e-01
epoch: 18, lr: 4.40e-04, steps: 32346, optimizer: Adam - train loss: 14.76 - valid loss: 8.43, valid ACC: 9.64e-01
epoch: 19, lr: 4.28e-04, steps: 34143, optimizer: Adam - train loss: 13.90 - valid loss: 8.03, valid ACC: 9.65e-01
epoch: 20, lr: 4.17e-04, steps: 35940, optimizer: Adam - train loss: 13.06 - valid loss: 7.65, valid ACC: 9.67e-01, valid WER: 4.19
epoch: 21, lr: 4.07e-04, steps: 37737, optimizer: Adam - train loss: 12.36 - valid loss: 7.65, valid ACC: 9.67e-01
epoch: 22, lr: 3.98e-04, steps: 39534, optimizer: Adam - train loss: 11.79 - valid loss: 7.17, valid ACC: 9.69e-01
epoch: 23, lr: 3.89e-04, steps: 41331, optimizer: Adam - train loss: 11.22 - valid loss: 7.13, valid ACC: 9.69e-01
epoch: 24, lr: 3.81e-04, steps: 43128, optimizer: Adam - train loss: 10.69 - valid loss: 7.51, valid ACC: 9.70e-01
epoch: 25, lr: 3.73e-04, steps: 44925, optimizer: Adam - train loss: 10.25 - valid loss: 7.07, valid ACC: 9.70e-01
epoch: 26, lr: 3.66e-04, steps: 46722, optimizer: Adam - train loss: 9.89 - valid loss: 7.32, valid ACC: 9.70e-01
epoch: 27, lr: 3.59e-04, steps: 48519, optimizer: Adam - train loss: 9.48 - valid loss: 7.59, valid ACC: 9.70e-01
epoch: 28, lr: 3.52e-04, steps: 50316, optimizer: Adam - train loss: 9.19 - valid loss: 6.93, valid ACC: 9.71e-01
epoch: 29, lr: 3.46e-04, steps: 52113, optimizer: Adam - train loss: 8.88 - valid loss: 7.61, valid ACC: 9.71e-01
epoch: 30, lr: 3.40e-04, steps: 53910, optimizer: Adam - train loss: 8.63 - valid loss: 7.05, valid ACC: 9.71e-01, valid WER: 3.66
epoch: 31, lr: 3.35e-04, steps: 55707, optimizer: Adam - train loss: 8.32 - valid loss: 6.92, valid ACC: 9.72e-01
epoch: 32, lr: 3.30e-04, steps: 57504, optimizer: Adam - train loss: 8.09 - valid loss: 7.27, valid ACC: 9.71e-01
epoch: 33, lr: 3.25e-04, steps: 59301, optimizer: Adam - train loss: 7.86 - valid loss: 7.48, valid ACC: 9.71e-01
epoch: 34, lr: 3.20e-04, steps: 61098, optimizer: Adam - train loss: 7.66 - valid loss: 7.32, valid ACC: 9.72e-01
epoch: 35, lr: 3.15e-04, steps: 62895, optimizer: Adam - train loss: 7.51 - valid loss: 7.33, valid ACC: 9.72e-01
epoch: 36, lr: 3.11e-04, steps: 64692, optimizer: Adam - train loss: 7.29 - valid loss: 7.33, valid ACC: 9.72e-01
epoch: 37, lr: 3.07e-04, steps: 66489, optimizer: Adam - train loss: 7.09 - valid loss: 7.08, valid ACC: 9.73e-01
epoch: 38, lr: 3.03e-04, steps: 68286, optimizer: Adam - train loss: 6.99 - valid loss: 7.46, valid ACC: 9.72e-01
epoch: 39, lr: 2.99e-04, steps: 70083, optimizer: Adam - train loss: 6.89 - valid loss: 7.24, valid ACC: 9.72e-01
epoch: 40, lr: 2.95e-04, steps: 71880, optimizer: Adam - train loss: 6.77 - valid loss: 7.41, valid ACC: 9.73e-01, valid WER: 3.40
I also tested a conformer_small to see if removing this would affect the performance of the previous model, and I got 2.5, which is equivalent to our current small Conformer performance.