
Fix Conformer Instabilities and add Large Model #1892

Merged
TParcollet merged 25 commits into speechbrain:develop from TParcollet:fix_conformer
Mar 24, 2023

Conversation

@TParcollet (Collaborator) commented Mar 22, 2023

We've been struggling with training larger Conformers due to instability, and this PR attempts to fix that. In practice, for some reason (please check, @popcornell), we were adding a sinusoidal positional embedding to the encoder output before giving it to the decoder. We don't do this for the standard Transformer, so I removed it and trained a large model.
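
For readers skimming the thread, here is a hedged sketch of the change; the names mirror the decode() branch quoted later in this conversation, but this is a paraphrase of the diff, not the literal patch:

# Before (RelPosMHAXL branch of decode()): a sinusoidal positional
# embedding was also added to the encoder output before cross-attention.
tgt = tgt + self.positional_encoding_decoder(tgt)
encoder_out = encoder_out + self.positional_encoding_decoder(
    encoder_out
)

# After: only the decoder targets receive the sinusoidal embedding,
# matching the standard Transformer recipe; encoder_out is left untouched.
tgt = tgt + self.positional_encoding_decoder(tgt)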

I also added support for bfloat16 training in core.py. This is a bit redundant with --auto_mix_prec though. @anautsch, what do you think? The problem is that the best option would be to rename --auto_mix_prec to --fp16_mix_prec, but this would affect so many files ... So maybe keeping --auto_mix_prec and adding --bfloat16_mix_prec will do :-)? @pplantinga, what do you think?
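
For context, a minimal sketch of how the two modes differ in plain PyTorch (an illustration under stated assumptions, not the actual core.py change): fp16 autocast usually needs a GradScaler to avoid gradient underflow, while bfloat16 keeps fp32's exponent range and can typically skip loss scaling, which is why a separate flag is attractive.

import torch

scaler = torch.cuda.amp.GradScaler()  # loss scaling is only needed for fp16

def fp16_step(model, batch, loss_fn, optimizer):
    # fp16 autocast: scale the loss so small gradients do not underflow.
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = loss_fn(model(batch))
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

def bf16_step(model, batch, loss_fn, optimizer):
    # bfloat16 autocast: same exponent range as fp32, so a plain backward works.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        loss = loss_fn(model(batch))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()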

PRE-FIX TRAIN LOG:
epoch: 1, lr: 3.59e-05, steps: 1797, optimizer: Adam - train loss: 3.06e+02 - valid loss: 2.09e+02, valid ACC: 1.83e-01
epoch: 2, lr: 7.19e-05, steps: 3594, optimizer: Adam - train loss: 2.54e+02 - valid loss: 1.69e+02, valid ACC: 2.58e-01
epoch: 3, lr: 1.08e-04, steps: 5391, optimizer: Adam - train loss: 1.66e+02 - valid loss: 66.64, valid ACC: 6.99e-01
epoch: 4, lr: 1.44e-04, steps: 7188, optimizer: Adam - train loss: 80.97 - valid loss: 33.55, valid ACC: 8.58e-01
epoch: 5, lr: 1.80e-04, steps: 8985, optimizer: Adam - train loss: 52.80 - valid loss: 23.88, valid ACC: 9.00e-01
epoch: 6, lr: 2.16e-04, steps: 10782, optimizer: Adam - train loss: 41.22 - valid loss: 19.67, valid ACC: 9.23e-01
epoch: 7, lr: 2.52e-04, steps: 12579, optimizer: Adam - train loss: 34.78 - valid loss: 16.37, valid ACC: 9.33e-01
epoch: 8, lr: 2.88e-04, steps: 14376, optimizer: Adam - train loss: 30.60 - valid loss: 15.25, valid ACC: 9.39e-01
epoch: 9, lr: 3.23e-04, steps: 16173, optimizer: Adam - train loss: 27.49 - valid loss: 13.32, valid ACC: 9.45e-01
epoch: 10, lr: 3.59e-04, steps: 17970, optimizer: Adam - train loss: 25.36 - valid loss: 12.12, valid ACC: 9.48e-01, valid WER: 6.76
epoch: 11, lr: 3.95e-04, steps: 19767, optimizer: Adam - train loss: 23.71 - valid loss: 11.84, valid ACC: 9.50e-01
epoch: 12, lr: 4.31e-04, steps: 21564, optimizer: Adam - train loss: 22.71 - valid loss: 10.87, valid ACC: 9.52e-01
epoch: 13, lr: 4.67e-04, steps: 23361, optimizer: Adam - train loss: 21.46 - valid loss: 10.01, valid ACC: 9.54e-01
epoch: 14, lr: 4.98e-04, steps: 25158, optimizer: Adam - train loss: 20.59 - valid loss: 10.16, valid ACC: 9.55e-01
epoch: 15, lr: 4.82e-04, steps: 26955, optimizer: Adam - train loss: 19.34 - valid loss: 9.60, valid ACC: 9.57e-01
epoch: 16, lr: 4.66e-04, steps: 28752, optimizer: Adam - train loss: 17.85 - valid loss: 8.99, valid ACC: 9.60e-01
epoch: 17, lr: 4.52e-04, steps: 30549, optimizer: Adam - train loss: 16.33 - valid loss: 8.85, valid ACC: 9.62e-01
epoch: 18, lr: 4.40e-04, steps: 32346, optimizer: Adam - train loss: 15.22 - valid loss: 8.10, valid ACC: 9.64e-01
epoch: 19, lr: 4.28e-04, steps: 34143, optimizer: Adam - train loss: 14.33 - valid loss: 8.08, valid ACC: 9.65e-01
epoch: 20, lr: 4.17e-04, steps: 35940, optimizer: Adam - train loss: 13.48 - valid loss: 7.77, valid ACC: 9.66e-01, valid WER: 6.04
epoch: 21, lr: 4.07e-04, steps: 37737, optimizer: Adam - train loss: 12.74 - valid loss: 8.06, valid ACC: 9.66e-01
epoch: 22, lr: 3.98e-04, steps: 39534, optimizer: Adam - train loss: 12.14 - valid loss: 7.56, valid ACC: 9.68e-01
epoch: 23, lr: 3.89e-04, steps: 41331, optimizer: Adam - train loss: 11.59 - valid loss: 7.79, valid ACC: 9.68e-01
epoch: 24, lr: 3.81e-04, steps: 43128, optimizer: Adam - train loss: 11.11 - valid loss: 7.39, valid ACC: 9.68e-01
epoch: 25, lr: 3.73e-04, steps: 44925, optimizer: Adam - train loss: 10.68 - valid loss: 7.27, valid ACC: 9.69e-01
epoch: 26, lr: 3.66e-04, steps: 46722, optimizer: Adam - train loss: 10.28 - valid loss: 7.51, valid ACC: 9.68e-01
epoch: 27, lr: 3.59e-04, steps: 48519, optimizer: Adam - train loss: 9.86 - valid loss: 6.90, valid ACC: 9.70e-01
epoch: 28, lr: 3.52e-04, steps: 50316, optimizer: Adam - train loss: 9.56 - valid loss: 7.27, valid ACC: 9.69e-01
epoch: 29, lr: 3.46e-04, steps: 52113, optimizer: Adam - train loss: 9.18 - valid loss: 7.06, valid ACC: 9.70e-01
epoch: 30, lr: 3.40e-04, steps: 53910, optimizer: Adam - train loss: 8.94 - valid loss: 7.38, valid ACC: 9.70e-01, valid WER: 8.63
epoch: 31, lr: 3.35e-04, steps: 55707, optimizer: Adam - train loss: 8.63 - valid loss: 6.77, valid ACC: 9.72e-01
epoch: 32, lr: 3.30e-04, steps: 57504, optimizer: Adam - train loss: 8.44 - valid loss: 7.10, valid ACC: 9.70e-01
epoch: 33, lr: 3.25e-04, steps: 59301, optimizer: Adam - train loss: 8.17 - valid loss: 6.92, valid ACC: 9.71e-01
epoch: 34, lr: 3.20e-04, steps: 61098, optimizer: Adam - train loss: 7.98 - valid loss: 6.98, valid ACC: 9.71e-01
epoch: 35, lr: 3.15e-04, steps: 62895, optimizer: Adam - train loss: 7.83 - valid loss: 6.98, valid ACC: 9.71e-01
epoch: 36, lr: 3.11e-04, steps: 64692, optimizer: Adam - train loss: 7.61 - valid loss: 7.13, valid ACC: 9.71e-01
epoch: 37, lr: 3.07e-04, steps: 66489, optimizer: Adam - train loss: 7.37 - valid loss: 6.85, valid ACC: 9.72e-01
epoch: 38, lr: 3.03e-04, steps: 68286, optimizer: Adam - train loss: 7.25 - valid loss: 6.79, valid ACC: 9.72e-01
epoch: 39, lr: 2.99e-04, steps: 70083, optimizer: Adam - train loss: 7.12 - valid loss: 6.96, valid ACC: 9.71e-01
epoch: 40, lr: 2.95e-04, steps: 71880, optimizer: Adam - train loss: 7.00 - valid loss: 7.03, valid ACC: 9.72e-01, valid WER: 6.01

POST-FIX TRAIN LOG:
epoch: 1, lr: 3.59e-05, steps: 1797, optimizer: Adam - train loss: 3.07e+02 - valid loss: 2.09e+02, valid ACC: 1.79e-01
epoch: 2, lr: 7.19e-05, steps: 3594, optimizer: Adam - train loss: 2.57e+02 - valid loss: 1.86e+02, valid ACC: 2.12e-01
epoch: 3, lr: 1.08e-04, steps: 5391, optimizer: Adam - train loss: 2.01e+02 - valid loss: 96.83, valid ACC: 5.54e-01
epoch: 4, lr: 1.44e-04, steps: 7188, optimizer: Adam - train loss: 97.51 - valid loss: 37.82, valid ACC: 8.36e-01
epoch: 5, lr: 1.80e-04, steps: 8985, optimizer: Adam - train loss: 58.00 - valid loss: 24.42, valid ACC: 8.93e-01
epoch: 6, lr: 2.16e-04, steps: 10782, optimizer: Adam - train loss: 44.37 - valid loss: 18.93, valid ACC: 9.20e-01
epoch: 7, lr: 2.52e-04, steps: 12579, optimizer: Adam - train loss: 36.70 - valid loss: 16.06, valid ACC: 9.33e-01
epoch: 8, lr: 2.88e-04, steps: 14376, optimizer: Adam - train loss: 31.20 - valid loss: 13.82, valid ACC: 9.39e-01
epoch: 9, lr: 3.23e-04, steps: 16173, optimizer: Adam - train loss: 27.64 - valid loss: 12.43, valid ACC: 9.45e-01
epoch: 10, lr: 3.59e-04, steps: 17970, optimizer: Adam - train loss: 25.32 - valid loss: 11.77, valid ACC: 9.49e-01, valid WER: 6.33
epoch: 11, lr: 3.95e-04, steps: 19767, optimizer: Adam - train loss: 23.50 - valid loss: 10.55, valid ACC: 9.53e-01
epoch: 12, lr: 4.31e-04, steps: 21564, optimizer: Adam - train loss: 22.25 - valid loss: 10.52, valid ACC: 9.53e-01
epoch: 13, lr: 4.67e-04, steps: 23361, optimizer: Adam - train loss: 20.92 - valid loss: 9.84, valid ACC: 9.55e-01
epoch: 14, lr: 4.98e-04, steps: 25158, optimizer: Adam - train loss: 20.06 - valid loss: 9.40, valid ACC: 9.57e-01
epoch: 15, lr: 4.82e-04, steps: 26955, optimizer: Adam - train loss: 18.82 - valid loss: 9.19, valid ACC: 9.60e-01
epoch: 16, lr: 4.66e-04, steps: 28752, optimizer: Adam - train loss: 17.32 - valid loss: 9.02, valid ACC: 9.61e-01
epoch: 17, lr: 4.52e-04, steps: 30549, optimizer: Adam - train loss: 15.80 - valid loss: 7.95, valid ACC: 9.64e-01
epoch: 18, lr: 4.40e-04, steps: 32346, optimizer: Adam - train loss: 14.76 - valid loss: 8.43, valid ACC: 9.64e-01
epoch: 19, lr: 4.28e-04, steps: 34143, optimizer: Adam - train loss: 13.90 - valid loss: 8.03, valid ACC: 9.65e-01
epoch: 20, lr: 4.17e-04, steps: 35940, optimizer: Adam - train loss: 13.06 - valid loss: 7.65, valid ACC: 9.67e-01, valid WER: 4.19
epoch: 21, lr: 4.07e-04, steps: 37737, optimizer: Adam - train loss: 12.36 - valid loss: 7.65, valid ACC: 9.67e-01
epoch: 22, lr: 3.98e-04, steps: 39534, optimizer: Adam - train loss: 11.79 - valid loss: 7.17, valid ACC: 9.69e-01
epoch: 23, lr: 3.89e-04, steps: 41331, optimizer: Adam - train loss: 11.22 - valid loss: 7.13, valid ACC: 9.69e-01
epoch: 24, lr: 3.81e-04, steps: 43128, optimizer: Adam - train loss: 10.69 - valid loss: 7.51, valid ACC: 9.70e-01
epoch: 25, lr: 3.73e-04, steps: 44925, optimizer: Adam - train loss: 10.25 - valid loss: 7.07, valid ACC: 9.70e-01
epoch: 26, lr: 3.66e-04, steps: 46722, optimizer: Adam - train loss: 9.89 - valid loss: 7.32, valid ACC: 9.70e-01
epoch: 27, lr: 3.59e-04, steps: 48519, optimizer: Adam - train loss: 9.48 - valid loss: 7.59, valid ACC: 9.70e-01
epoch: 28, lr: 3.52e-04, steps: 50316, optimizer: Adam - train loss: 9.19 - valid loss: 6.93, valid ACC: 9.71e-01
epoch: 29, lr: 3.46e-04, steps: 52113, optimizer: Adam - train loss: 8.88 - valid loss: 7.61, valid ACC: 9.71e-01
epoch: 30, lr: 3.40e-04, steps: 53910, optimizer: Adam - train loss: 8.63 - valid loss: 7.05, valid ACC: 9.71e-01, valid WER: 3.66
epoch: 31, lr: 3.35e-04, steps: 55707, optimizer: Adam - train loss: 8.32 - valid loss: 6.92, valid ACC: 9.72e-01
epoch: 32, lr: 3.30e-04, steps: 57504, optimizer: Adam - train loss: 8.09 - valid loss: 7.27, valid ACC: 9.71e-01
epoch: 33, lr: 3.25e-04, steps: 59301, optimizer: Adam - train loss: 7.86 - valid loss: 7.48, valid ACC: 9.71e-01
epoch: 34, lr: 3.20e-04, steps: 61098, optimizer: Adam - train loss: 7.66 - valid loss: 7.32, valid ACC: 9.72e-01
epoch: 35, lr: 3.15e-04, steps: 62895, optimizer: Adam - train loss: 7.51 - valid loss: 7.33, valid ACC: 9.72e-01
epoch: 36, lr: 3.11e-04, steps: 64692, optimizer: Adam - train loss: 7.29 - valid loss: 7.33, valid ACC: 9.72e-01
epoch: 37, lr: 3.07e-04, steps: 66489, optimizer: Adam - train loss: 7.09 - valid loss: 7.08, valid ACC: 9.73e-01
epoch: 38, lr: 3.03e-04, steps: 68286, optimizer: Adam - train loss: 6.99 - valid loss: 7.46, valid ACC: 9.72e-01
epoch: 39, lr: 2.99e-04, steps: 70083, optimizer: Adam - train loss: 6.89 - valid loss: 7.24, valid ACC: 9.72e-01
epoch: 40, lr: 2.95e-04, steps: 71880, optimizer: Adam - train loss: 6.77 - valid loss: 7.41, valid ACC: 9.73e-01, valid WER: 3.40

I also tested a conformer_small to see if removing this would affect the performance of the previous model, and I got 2.5, which is equivalent to our current small Conformer performance.

else:
    outputs = self.compute_forward(batch, Stage.TRAIN)
    loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
if self.bfloat16_mix_prec:
Collaborator:
I think something like this is a good idea. This is not a necessary change, but a slightly more succinct way to write this would be something like:

with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=self.bfloat16_mix_prec):
    outputs = self.compute_forward(batch, Stage.TRAIN)
    loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
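
One nice property of this form: torch.cuda.amp.autocast with enabled=False is a no-op context, so the same code path serves both the bfloat16 and full-precision cases without duplicating the forward/objective calls.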

Collaborator Author:
Yeah true, I'll fix this.

Collaborator Author:
Done.

@popcornell (Collaborator):
Did you still use the relative positional encoding attention?
It is strange that adding positional encodings brings such a WER degradation.
And based on your previous results, it did not affect CTC, only the decoder.

encoder_out = encoder_out + self.positional_encoding_decoder(
    encoder_out
)
# encoder_out = encoder_out + self.positional_encoding_decoder(
Collaborator:
Probably we could just remove it.

Collaborator Author:
Done.

@TParcollet (Collaborator Author):
Honestly, I don't know. I just saw this as the only difference between the Transformer and the Conformer, removed it, tried a bigger model, and it was stable...

)
else:
    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
    loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
Collaborator:
There are other recipes, too:

grep -rn auto_mix_prec recipes | grep py

Is there another way of handling this? The mixed-precision structure/logic is independent of the actual steps taken.

How about putting everything below the logic steps into different core functions that can be inherited & overwritten?

                    outputs = self.compute_forward(batch, sb.Stage.TRAIN)
                    loss = self.compute_objectives(
                        outputs, batch, sb.Stage.TRAIN
                    )

does not change, regardless of float precision. I get that the copy/paste is easy, but with more nested logic where `with ...:` blocks wrap the same code, it's prone to mistakes.

Collaborator:
It will affect the many other files in one way or another. Because of that, putting most of this logic into the core module would be more practical (thinking about future maintenance running into this again).

Collaborator Author:
Recipes that have overloaded fit_batch won't benefit from this change. That is expected, and I am not going to refactor this: overriding the fit_batch function is a deliberate design choice. Unfortunately, it is impossible to wrap this mixed-precision stuff outside of fit_batch, because it is tied to the forward/backward passes. Also, some operations are not stable with mixed precision, so the user must change them in fit_batch if needed.
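
To make that concrete, a minimal sketch, assuming the bfloat16_mix_prec flag from this PR and the standard Brain hooks, of a hypothetical recipe overriding fit_batch and opting a numerically sensitive objective out of reduced precision (an illustration, not part of this diff):

import torch
import speechbrain as sb

class MyASRBrain(sb.Brain):
    def fit_batch(self, batch):
        # The autocast context must live here, inside fit_batch, because
        # it has to wrap the forward pass that builds the autograd graph.
        with torch.cuda.amp.autocast(
            dtype=torch.bfloat16, enabled=self.bfloat16_mix_prec
        ):
            outputs = self.compute_forward(batch, sb.Stage.TRAIN)
            # Ops that are unstable in reduced precision can opt out locally;
            # autocast stops casting here, though tensors that are already
            # bf16 keep their dtype and may need explicit .float() casts.
            with torch.cuda.amp.autocast(enabled=False):
                loss = self.compute_objectives(outputs, batch, sb.Stage.TRAIN)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.detach().cpu()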

Collaborator:
Yes, I'm suggesting splitting fit_batch up further for more detailed handling of this logic, since this is clearly a PyTorch wrapping issue. It will re-appear.

Collaborator:
In some cases the newly added on_fit_batch_end can address this, see discussion at #1864 (comment)

# encoder_out
# )
pos_embs_encoder = None  # self.positional_encoding(src)
pos_embs_target = None
Collaborator:
positional_encoding_decoder is only set for attention_type == "RelPosMHAXL"

        if attention_type == "RelPosMHAXL":
            self.positional_encoding = RelPosEncXL(d_model)
            self.positional_encoding_decoder = PositionalEncoding(
                d_model, max_length
            )

So - what was this FIXME about in the first place?


Btw - it's also used in G2P:

grep -r positional_encoding_decoder speechbrain

What is this line about?

pos_embs_encoder = None # self.positional_encoding(src)
pos_embs_encoder = self.positional_encoding(src)  # if we drop the other one, as you suggested

Maybe the comment there can be dropped (it would relate to the forward in class RelPosMHAXL(nn.Module)).

...
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
...
p_k = self.linear_pos(pos_embs).view(
    1, -1, self.num_heads, self.head_dim
)

@anautsch (Collaborator):
At which level is the positional encoding meant to work?

The paper https://arxiv.org/pdf/1901.02860.pdf cross-referenced in the SB readthedocs refers in "3.3 Relative Positional Encodings" to:

where the i-th row U_i corresponds to the i-th absolute position within a segment and L_max prescribes the maximum possible length to be modeled.

for which they add:

Notice that, both E_{s_τ} and E_{s_{τ+1}} are associated with the same positional encoding U_{1:L}. As a result, the model has no information to distinguish the positional difference between x_{τ,j} and x_{τ+1,j} for any j = 1, . . . , L, resulting in a sheer performance loss.

In order to avoid this failure mode, the fundamental idea is to only encode the relative positional information in the hidden states. [...] For the same purpose, instead of incorporating bias statically into the initial embedding, one can inject the same information into the attention score of each layer. More importantly, it is more intuitive and generalizable to define the temporal bias in a relative manner.

So, the crucial thing during training is to have the relative encoding in the target?
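
To make the paper's point concrete, a toy sketch (a simplification for illustration, not SpeechBrain's RelPosEncXL) of building sinusoidal embeddings indexed by the distance i - j rather than by absolute position; each attention layer then consumes these when scoring position i against position j:

import torch

def relative_sinusoids(seq_len: int, d_model: int) -> torch.Tensor:
    # Distances run from (seq_len - 1) down to -(seq_len - 1), covering
    # every possible offset between a query and a key position.
    pos = torch.arange(seq_len - 1, -seq_len, -1.0).unsqueeze(1)
    inv_freq = 1.0 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model))
    angles = pos * inv_freq
    emb = torch.zeros(pos.size(0), d_model)
    emb[:, 0::2] = torch.sin(angles)
    emb[:, 1::2] = torch.cos(angles)
    return emb  # shape: (2 * seq_len - 1, d_model)

In RelPosMHAXL these embeddings are projected per head (the linear_pos layer quoted above) and enter the attention score directly, so the model only ever sees distances between positions, never absolute indices.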

Btw... got curious about the if/else construction:

if self.attention_type == "RelPosMHAXL":
    # use standard sinusoidal pos encoding in decoder
    tgt = tgt + self.positional_encoding_decoder(tgt)
    # FIXME we use pos embs also on enc output
    encoder_out = encoder_out + self.positional_encoding_decoder(
        encoder_out
    )
    pos_embs_encoder = None  # self.positional_encoding(src)
    pos_embs_target = None
elif self.positional_encoding_type == "fixed_abs_sine":
    tgt = tgt + self.positional_encoding(tgt)
    pos_embs_target = None
    pos_embs_encoder = None

Compare:

if positional_encoding == "fixed_abs_sine":
    self.positional_encoding = PositionalEncoding(d_model, max_length)
elif positional_encoding is None:
    pass
    # no positional encodings

# overrides any other pos_embedding
if attention_type == "RelPosMHAXL":
    self.positional_encoding = RelPosEncXL(d_model)
    self.positional_encoding_decoder = PositionalEncoding(
        d_model, max_length
    )

and during encode():

if self.attention_type == "RelPosMHAXL":
    pos_embs_source = self.positional_encoding(src)
elif self.positional_encoding_type == "fixed_abs_sine":
    src = src + self.positional_encoding(src)
    pos_embs_source = None


Are there tutorials for how paper write-ups relate to code?

@TParcollet (Collaborator Author):
@anautsch I cleaned up this if/else. The bug is just that we should not do that, I guess :p

TParcollet merged commit 278bc6f into speechbrain:develop on Mar 24, 2023