
Rotary Position Embedding (RoPE) for ASR (code from Samsung Cambridge)#2799

Merged
Adel-Moumen merged 50 commits intospeechbrain:developfrom
shucongzhang:rope_pr
Mar 10, 2025

Conversation

@shucongzhang
Contributor

What does this PR do?

This PR implements Rotary Position Embedding (RoPE) https://arxiv.org/pdf/2104.09864.

It improves the training speed on LibriSpeech by 13%, while giving slightly better ASR results.

RoPE is also useful for other datasets. We have tested it with LibriHeavy, CommonVoice, and VoxPopuli. The full results are shown in this short paper: https://arxiv.org/pdf/2501.06051. In this PR, we only submit the LibriSpeech RoPE recipe for simplicity. If this can be merged, we would also like to submit recipes for the other datasets.
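For readers unfamiliar with RoPE, here is a minimal PyTorch sketch of the core idea (an illustration, not the SpeechBrain implementation): each consecutive feature pair of the query/key vectors is rotated by an angle proportional to its position, so attention scores end up depending only on relative positions.

```python
import torch

def rope_rotate(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embedding to x of shape (batch, seq, dim).

    Minimal sketch of RoPE (Su et al., 2021): each consecutive pair of
    features is rotated by an angle that grows linearly with position.
    """
    _, seq_len, dim = x.shape
    # One frequency per feature pair: theta_i = base^(-2i/dim)
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (seq, dim/2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]        # split into feature pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin       # 2-D rotation per pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Since RoPE is a pure rotation, it preserves vector norms, and position 0 is left unchanged (all angles are zero there).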

@TParcollet
Collaborator

TParcollet commented Jan 13, 2025

@pplantinga @Adel-Moumen @mravanelli this is an important PR in the sense that we should, from now on, use RoPE for most models we develop. The reason is that it's definitely faster and better. We may want to retrain some models with it ... it's also a good distinction from ESPnet and NeMo, I believe.

I can do a review, as I am not that familiar with the code, but we'll need an external one as well.

Collaborator

@TParcollet TParcollet left a comment


Thanks @shucongzhang, just a few minor comments. Also, as discussed together, let's see how the discussion about Torch attention vs homemade attention goes with the others.

loss_reduction: 'batchmean'
sorting: random
num_workers: 4
precision: fp32 # bf16, fp16 or fp32
Collaborator


Change to fp16 by default, please.

@@ -0,0 +1,338 @@
# ############################################################################
Collaborator


This guy looks VERY similar to our conformer_large.yaml, right? Why not change the main one to use RoPE? The model is just better, so I don't see any problem with that.

raise ValueError(
"The chosen attention type for the Conformer is RelPosMHAXL. For this attention type, the positional embeddings are mandatory"
)
elif self.attention_type == "RoPEMHAXL":
Collaborator


Why the XL? Why not just RoPEMHA?

Comment thread speechbrain/lobes/models/transformer/Transformer.py Outdated
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
attn_mask=memory_mask, # none
Collaborator


remove comments?

pos_embs=pos_embs_src,
)

# breakpoint()
Collaborator


remove.

branchformer_activation: Optional[nn.Module] = nn.GELU,
attention_type: Optional[str] = "regularMHA",
max_length: Optional[int] = 2500,
max_length: Optional[int] = 10000,
Collaborator


This breaks backward compatibility. I 100% agree that 2500 is too little, but it will break backward compatibility when loading old checkpoints... can we keep it at 2500?

Comment thread speechbrain/nnet/attention.py
Comment thread speechbrain/nnet/attention.py Outdated
positions_inv_freq = torch.outer(positions, inv_freq)

cosines = torch.cos(positions_inv_freq)
# (cos(m*theta_0), cos(m*theta_0), cos(m*theta_1), cos(m*theta_1) ,... ) for equantion (34)
Collaborator


equation (typo)

Comment thread speechbrain/nnet/attention.py Outdated
)

sines = torch.sin(positions_inv_freq)
# (sin(m*theta_0), sin(m*theta_0), sin(m*theta_1), sin(m*theta_1) ,... ) for equantion (34)
Collaborator


equation (typo)
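The two comments above refer to equation (34) of the RoPE paper, where each cosine and sine is duplicated so the tables line up with the interleaved feature pairs. A hedged sketch of that construction (`dim`, `seq_len`, and `base` are illustrative values; variable names follow the diff):

```python
import torch

dim, seq_len, base = 8, 4, 10000.0

# theta_i = base^(-2i/dim), one frequency per feature pair
inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
positions = torch.arange(seq_len, dtype=torch.float32)
positions_inv_freq = torch.outer(positions, inv_freq)  # (seq_len, dim // 2)

# (cos(m*theta_0), cos(m*theta_0), cos(m*theta_1), cos(m*theta_1), ...)
cosines = torch.cos(positions_inv_freq).repeat_interleave(2, dim=-1)
# (sin(m*theta_0), sin(m*theta_0), sin(m*theta_1), sin(m*theta_1), ...)
sines = torch.sin(positions_inv_freq).repeat_interleave(2, dim=-1)
```

`repeat_interleave(2, dim=-1)` duplicates each column in place, so columns 2i and 2i+1 both hold the angle for frequency theta_i.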

Comment thread speechbrain/nnet/attention.py Outdated
key_padding_mask = key_padding_mask.view(bsz, 1, 1, klen).expand(
bsz, self.num_heads, klen, qlen
)
torch.logical_not(key_padding_mask)
Collaborator


Does this line do anything? Shouldn't it be stored back into a variable?
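For context on this review comment: `torch.logical_not` is not an in-place operation; it returns a new tensor, so the result must be assigned back (or the in-place `logical_not_` used). A small demonstration:

```python
import torch

key_padding_mask = torch.tensor([[False, False, True]])  # True marks padding

torch.logical_not(key_padding_mask)  # result discarded; mask is unchanged
assert key_padding_mask[0, 2].item()

# Either keep the returned tensor ...
key_padding_mask = torch.logical_not(key_padding_mask)
# ... or invert in place with key_padding_mask.logical_not_()
assert not key_padding_mask[0, 2].item()
```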

@TParcollet
Collaborator

I think that there is a software/hardware curiosity that could be of interest to @asumagic @pplantinga and @Adel-Moumen here. I just added a unit test comparing our homemade attention to Torch attention (see the unit test folder). You will quickly see that this test passes easily on CPU, but fails miserably on GPU. I am wondering if this is an actual cuDNN issue or something else?

@asumagic
Collaborator

I don't really know; could it just be expected precision loss, or are the values significantly wrong?
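A sketch of the kind of comparison being discussed, pitting a naive reference implementation against PyTorch's fused `scaled_dot_product_attention` (fp32 on CPU here; on GPU the fused kernels reorder the reductions, so larger tolerances are expected, especially in fp16/bf16):

```python
import torch

def manual_attention(q, k, v):
    # Reference implementation: explicit softmax(QK^T / sqrt(d)) V
    d = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d**0.5
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # (batch, heads, seq, dim)
ref = manual_attention(q, k, v)
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)

# On CPU in fp32 both paths typically agree to tight tolerances.
print(torch.allclose(ref, fused, atol=1e-5, rtol=1e-5))
```

When such a test fails only on GPU, a useful first step is to inspect the maximum absolute difference: a value on the order of the dtype's epsilon suggests accumulated rounding from a different reduction order rather than a genuine bug.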

Titouan Parcollet and others added 5 commits February 10, 2025 17:26
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
@pplantinga pplantinga added this to the v1.0.3 milestone Feb 14, 2025
@TParcollet
Collaborator

@Adel-Moumen ready for the review. All the above comments only need a very brief check, as we have worked internally (all three of us) to make it better. You can give your input now.

Comment thread recipes/LibriSpeech/ASR/transformer/README.md Outdated
TParcollet and others added 3 commits February 25, 2025 14:51
Co-authored-by: Rogier van Dalen <r.vandalen@samsung.com>
Make RoPE memoisation clearer (#4)
@pplantinga pplantinga linked an issue Feb 26, 2025 that may be closed by this pull request
Comment thread speechbrain/nnet/attention.py Outdated
Comment thread speechbrain/nnet/attention.py Outdated
Comment thread speechbrain/nnet/attention.py Outdated
Collaborator

@pplantinga pplantinga left a comment


Very minor last cleanup comments.

Comment thread speechbrain/lobes/models/transformer/Transformer.py Outdated
Comment thread speechbrain/lobes/models/transformer/Transformer.py Outdated
Comment thread speechbrain/lobes/models/transformer/Transformer.py Outdated
Collaborator

@TParcollet TParcollet left a comment


I am happy with the code, but it would need a last check.

@Adel-Moumen
Collaborator

I just ran the recipe tests and they worked for this recipe.

@Adel-Moumen Adel-Moumen merged commit 7724216 into speechbrain:develop Mar 10, 2025
5 checks passed
pplantinga pushed a commit to pplantinga/speechbrain that referenced this pull request Jun 2, 2025


Development

Successfully merging this pull request may close these issues.

Positional Rotary Embeddings for transformers

6 participants