Rotary Position Embedding (RoPE) for ASR (code from Samsung Cambridge)#2799
Adel-Moumen merged 50 commits into speechbrain:develop from rope
Conversation
rope with the latest SB
…into rope up-to-date with develop
merge the develop into rope
@pplantinga @Adel-Moumen @mravanelli this is an important PR in the sense that we should, from now on, use RoPE for most models we develop. The reason is that it's definitely faster, and better. We may want to retrain some models with it... it's also a good distinction from ESPnet and NeMo, I believe. I can do a review, as I am not that familiar with the code. But we'll need an external one as well.
TParcollet
left a comment
Thanks @shucongzhang, just a few minor comments. Also, as discussed together, let's see how the discussion about Torch attention vs. homemade attention goes with the others.
| loss_reduction: 'batchmean' | ||
| sorting: random | ||
| num_workers: 4 | ||
| precision: fp32 # bf16, fp16 or fp32 |
Change to fp16 by default, please.
| @@ -0,0 +1,338 @@ | |||
| # ############################################################################ | |||
This guy looks VERY similar to our conformer_large.yaml, right? Why not change the main one to use RoPE? The model is just better, so I don't see any problem with that.
| raise ValueError( | ||
| "The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory" | ||
| ) | ||
| elif self.attention_type == "RoPEMHAXL": |
Why the XL? Why not just RoPEMHA?
| value=memory, | ||
| attn_mask=memory_mask, | ||
| key_padding_mask=memory_key_padding_mask, | ||
| attn_mask=memory_mask, # none |
| pos_embs=pos_embs_src, | ||
| ) | ||
| # breakpoint() |
| branchformer_activation: Optional[nn.Module] = nn.GELU, | ||
| attention_type: Optional[str] = "regularMHA", | ||
| max_length: Optional[int] = 2500, | ||
| max_length: Optional[int] = 10000, |
This breaks backward compatibility. I 100% agree that 2500 is too little, but it will break loading of old checkpoints... can we keep it at 2500?
| positions_inv_freq = torch.outer(positions, inv_freq) | ||
| cosines = torch.cos(positions_inv_freq) | ||
| # (cos(m*theta_0), cos(m*theta_0), cos(m*theta_1), cos(m*theta_1) ,... ) for equation (34) |
| ) | ||
| sines = torch.sin(positions_inv_freq) | ||
| # (sin(m*theta_0), sin(m*theta_0), sin(m*theta_1), sin(m*theta_1) ,... ) for equation (34) |
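The duplicated-coefficient layout described in the comments in the diff — each angle's cosine and sine repeated for adjacent feature pairs, per equation (34) of the RoPE paper (https://arxiv.org/pdf/2104.09864) — can be sketched in NumPy as follows. `rope_tables` is an illustrative name, not the PR's actual API:

```python
import numpy as np

def rope_tables(seq_len, dim, base=10000.0):
    # theta_i = base^(-2i/dim), i = 0 .. dim/2 - 1, as in the RoPE paper
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)
    positions = np.arange(seq_len, dtype=np.float64)
    angles = np.outer(positions, inv_freq)           # (seq_len, dim/2)
    # Duplicate every column so adjacent feature pairs share one angle:
    # row m becomes (cos(m*theta_0), cos(m*theta_0), cos(m*theta_1), ...)
    cosines = np.repeat(np.cos(angles), 2, axis=-1)  # (seq_len, dim)
    sines = np.repeat(np.sin(angles), 2, axis=-1)    # (seq_len, dim)
    return cosines, sines
```

At position 0 all angles are zero, so the rotation built from these tables reduces to the identity.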
| key_padding_mask = key_padding_mask.view(bsz, 1, 1, klen).expand( | ||
| bsz, self.num_heads, klen, qlen | ||
| ) | ||
| torch.logical_not(key_padding_mask) |
Does this line do anything? Shouldn't it be stored back into a variable?
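The reviewer's point holds: `torch.logical_not` is out-of-place and returns a new tensor, so calling it without assigning the result is a no-op on the original mask (the in-place spelling would be `key_padding_mask.logical_not_()`). A minimal NumPy analogue of the bug and its fix:

```python
import numpy as np

mask = np.array([True, False, True])

np.logical_not(mask)                # return value discarded; mask is unchanged
assert mask.tolist() == [True, False, True]

mask = np.logical_not(mask)         # fix: store the result back
assert mask.tolist() == [False, True, False]
```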
I think there is a software/hardware curiosity here that could be of interest to @asumagic @pplantinga and @Adel-Moumen. I just added a unit test comparing our homemade attention to torch attention (see the unit test folder). You will quickly see that this test passes easily on CPU, but fails miserably on GPU. I am wondering if this is an actual cuDNN issue or something else?
I don't really know, can it just be expected precision loss? Or are the values significantly wrong?
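Whatever the kernel-level cause turns out to be, part of the gap is expected: floating-point addition is not associative, and GPU attention kernels accumulate sums in a different order than the CPU reference, so fp32 results can legitimately drift past tight `allclose` tolerances. A minimal NumPy demonstration of the non-associativity:

```python
import numpy as np

# At magnitude 1e8, the spacing between adjacent float32 values is 8,
# so adding 1.0 before cancelling the large term loses it entirely.
big = np.float32(1e8)
one = np.float32(1.0)

left_to_right = (big + one) - big   # the 1.0 is absorbed into big
cancel_first = (big - big) + one    # cancellation happens before the add

print(left_to_right, cancel_first)
```

The two expressions are algebraically identical but evaluate differently in fp32, which is exactly the kind of reordering a fused GPU kernel performs.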
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
@Adel-Moumen ready for the review. All the above comments need only a brief check, as we have worked internally (all three of us) to improve it. You can give your input now.
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
Make RoPE memoisation clearer (#4)
Fix comments of review
pplantinga
left a comment
Very minor final cleanup comments.
TParcollet
left a comment
I am happy with the code, but it would need a last check.
I just ran the recipe tests and they worked for this recipe.
What does this PR do?
This PR implements Rotary Position Embedding (RoPE) https://arxiv.org/pdf/2104.09864.
It improves the training speed on LibriSpeech by 13%, while giving slightly better ASR results.
RoPE is also useful for other datasets. We have tested it with LibriHeavy, CommonVoice, and VoxPopuli. The full results are shown in this short paper: https://arxiv.org/pdf/2501.06051. In this PR, we only submit the RoPE recipe for LibriSpeech, for simplicity. If this is merged, we would like to also submit recipes for the other datasets.
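One reason RoPE composes well with fast attention is that the rotation is applied to queries and keys only, so the scores need no extra positional tensor and the core matmul is unchanged. An illustrative NumPy sketch of that structure (not the PR's implementation; function names are hypothetical):

```python
import numpy as np

def rope(x, base=10000.0):
    """Rotary position embedding for x of shape (seq_len, dim), dim even."""
    seq_len, dim = x.shape
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)
    angles = np.outer(np.arange(seq_len), inv_freq)  # (seq_len, dim/2)
    cos = np.repeat(np.cos(angles), 2, axis=-1)      # duplicate per pair
    sin = np.repeat(np.sin(angles), 2, axis=-1)
    x_rot = np.empty_like(x)                         # pairwise (x0,x1)->(-x1,x0)
    x_rot[..., 0::2] = -x[..., 1::2]
    x_rot[..., 1::2] = x[..., 0::2]
    return x * cos + x_rot * sin

def attention_with_rope(q, k, v):
    # Rotate queries and keys; values and the score matmul are untouched,
    # so no learned positional parameters enter the attention itself.
    q, k = rope(q), rope(k)
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Because position 0 has all-zero angles, `rope(x)[0]` equals `x[0]`, and the dot product between a rotated query and key depends only on their relative offset, which is the property the paper exploits.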