switch to RMSNorm from pytorch#447

Draft
mserranos wants to merge 5 commits into main from mjs/rms_norm
Conversation

@mserranos
Collaborator

No description provided.

@mserranos mserranos requested a review from JRosenkranz July 10, 2025 14:36
Comment thread fms/modules/rmsnorm.py Outdated
import torch


class RMSNormFMS(torch.nn.RMSNorm):
Collaborator

@JRosenkranz JRosenkranz Jul 14, 2025

I don't think we need a new class for this; we can just reuse RMSNorm in place of LayerNormParameterized. use_high_precision_pow should most likely always be True. It looks like the C++ implementation upcasts explicitly as well:

https://github.com/pytorch/pytorch/blob/6ea91f067256447cda6fae533f806c1f8baafbe2/aten/src/ATen/native/layer_norm.cpp#L301
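
To make the precision point concrete, here is a minimal pure-Python sketch of the RMSNorm formula, `y = x / sqrt(mean(x^2) + eps) * weight`, computed once with the reduction in full precision and once with every intermediate rounded to half precision. The helper names (`to_half`, `rms_norm`, `rms_norm_low_precision`) are hypothetical and exist only for this illustration; the sketch emulates float16 with the standard library and is not the FMS or PyTorch implementation, nor a statement of exactly what use_high_precision_pow toggles.

```python
import math
import struct


def to_half(value: float) -> float:
    """Round a float to IEEE half precision, saturating to +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # Finite values beyond the float16 range cannot be packed; treat as inf.
        return math.copysign(math.inf, value)


def rms_norm(x, weight, eps=1e-6):
    """RMSNorm with the mean of squares computed in full (float64) precision.

    Only the final result is rounded back to half precision.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [to_half(v * scale * w) for v, w in zip(x, weight)]


def rms_norm_low_precision(x, weight, eps=1e-6):
    """Same formula, but every intermediate is rounded to half precision."""
    total = 0.0
    for v in x:
        # Squaring a moderately large value already exceeds the float16 range.
        total = to_half(total + to_half(v * v))
    mean_sq = to_half(total / len(x))
    scale = to_half(1.0 / math.sqrt(to_half(mean_sq + eps)))
    return [to_half(to_half(v * scale) * w) for v, w in zip(x, weight)]


if __name__ == "__main__":
    x = [300.0, 400.0]  # both values are exactly representable in float16
    w = [1.0, 1.0]
    print(rms_norm(x, w))
    print(rms_norm_low_precision(x, w))
```

In this sketch the squares of 300 and 400 overflow the half-precision range, so the low-precision variant ends up with an infinite mean of squares and a scale of zero, while the upcast variant returns values close to 0.85 and 1.13. That is the kind of failure an explicit upcast before the power is meant to avoid.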

Collaborator Author

Agreed, we could clean this up more, but let's go with the initial cleanup, and once deeptools is ready we can proceed with this. #449

Mauricio J Serrano and others added 5 commits July 15, 2025 07:10
Signed-off-by: Mauricio J Serrano <mserrano@us.ibm.com>
Signed-off-by: Mauricio J Serrano <mserrano@us.ibm.com>
Signed-off-by: Mauricio J Serrano <mserrano@us.ibm.com>
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
Signed-off-by: Mauricio J Serrano <mserrano@us.ibm.com>
Signed-off-by: Mauricio J Serrano <mserrano@us.ibm.com>
3 participants