fix: dtype mismatch in matmul_lora and LoRA backward with bnb-4bit + GRPO#4918
anandn1 wants to merge 6 commits into unslothai:main
Conversation
…GRPO
Two root causes fixed:
1. matmul_lora (utils.py): fast_dequantize reads from a global buffer
whose dtype is controlled by quant_state.dtype embedded in bnb-4bit
checkpoints (typically float16), not by the dtype= arg passed at load
time. When activations are bfloat16, the subsequent matmul crashes with
"got Half and BFloat16". Fix: cast W to activation dtype after
fast_dequantize. Same cast applied to `out` in the LoRA branch.
2. LoRA_MLP_SwiGLU.backward and LoRA_QKV.backward (fast_lora.py):
@torch_amp_custom_bwd inherits the float16 autocast context established
by TRL's compiled GRPO trainer. This silently downcasts float32 gradient
tensors (dY, dQ/dK/dV) to float16 mid-computation, causing addmm_
dtype mismatches. Fix: wrap entire backward body in
torch.amp.autocast("cuda", enabled=False) and explicitly cast all
incoming gradient tensors and dequantized base weights to X.dtype.
Reproducer: Llama-3.2-3B-Instruct-bnb-4bit + GRPO, bf16=True, fp16=False,
unsloth 2026.4.4, TRL 0.22.x, CUDA 12.8.
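The forward-path fix described in the commit message can be sketched in isolation. This is a minimal sketch, not the actual unsloth code: `matmul_lora_base` is a hypothetical stand-in for the base matmul inside `matmul_lora`, and a plain float16 tensor simulates the `fast_dequantize` global-buffer output.

```python
import torch

def matmul_lora_base(X, W):
    # Hypothetical stand-in for the base matmul inside matmul_lora.
    # W plays the role of the fast_dequantize output, whose dtype follows
    # quant_state.dtype (float16 in typical bnb-4bit checkpoints) rather
    # than the dtype= argument passed at load time.
    if W.dtype != X.dtype:
        W = W.to(X.dtype)  # the fix: align W with the activation dtype
    return torch.matmul(X, W.t())

X = torch.randn(2, 4, dtype=torch.bfloat16)  # bf16 activations
W = torch.randn(8, 4, dtype=torch.float16)   # fp16 "dequantized" weight
out = matmul_lora_base(X, W)
print(out.dtype)  # torch.bfloat16
```

Without the cast, `torch.matmul` on mixed Half/BFloat16 operands raises the "got Half and BFloat16" error quoted above.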
for more information, see https://pre-commit.ci
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 6b6b5d83d9
Code Review
This pull request implements explicit dtype management and disables autocast during the backward pass of the LoRA kernels to resolve `addmm_` mismatches caused by silent downcasting in certain training environments. The feedback focuses on improving device compatibility by replacing hardcoded "cuda" strings with dynamic device types, and on removing several redundant casts of tensors that already match the target activation dtype.
I was unable to create individual review comments, so the feedback is listed below.
unsloth/kernels/fast_lora.py (142)
Hardcoding "cuda" in torch.amp.autocast limits compatibility with other devices. Using X.device.type makes the context manager device-agnostic.
with torch.amp.autocast(X.device.type, enabled=False):
unsloth/kernels/fast_lora.py (144-146)
The casts for e and g are redundant because they are produced in the forward pass using matmul_lora, which already returns tensors in the activation dtype. Only the cast for dY is necessary.
if dY.dtype != dtype: dY = dY.to(dtype)
unsloth/kernels/fast_lora.py (168-171)
These casts are redundant as DW is already in the correct dtype.
h, df, de = DW, e, g
unsloth/kernels/fast_lora.py (462)
Hardcoding "cuda" in torch.amp.autocast limits compatibility with other devices. Using X.device.type makes the context manager device-agnostic.
with torch.amp.autocast(X.device.type, enabled=False):
unsloth/kernels/utils.py (1051)
X.to(dtype) is redundant here because dtype is defined as X.dtype at the start of the function.
XA = torch_matmul(X, A.to(dtype))
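On the redundancy point: `Tensor.to` returns the tensor itself (no copy) when the dtype and device already match, so a cast like `X.to(dtype)` with `dtype = X.dtype` is a no-op. A quick check:

```python
import torch

x = torch.randn(2, 3, dtype=torch.bfloat16)
dtype = x.dtype  # mirrors `dtype = X.dtype` at the top of matmul_lora

# .to() returns self when dtype and device already match,
# which is why the flagged casts are redundant.
same = x.to(dtype)
print(same is x)  # True
```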
Pull request overview
Fixes mixed-dtype failures when using bitsandbytes 4-bit checkpoints with bfloat16 activations (notably during GRPO training), by enforcing dtype alignment in the LoRA matmul path and preventing inherited autocast contexts from downcasting gradients during custom backward passes.
Changes:
- Cast dequantized 4-bit base weights in `matmul_lora` to match the activation dtype before matmul/addmm.
- In LoRA custom backward functions, disable autocast and explicitly cast incoming gradients and dequantized weights to `X.dtype` to avoid `addmm_` dtype mismatches.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| unsloth/kernels/utils.py | Adds dtype-normalization around dequantized weights and LoRA accumulation in matmul_lora. |
| unsloth/kernels/fast_lora.py | Disables autocast in LoRA custom backward passes and aligns gradient/weight dtypes to the activation dtype. |
```python
W = fast_dequantize(W, W_quant, use_global_buffer = True)
# fast_dequantize may return float16 from the global buffer even when
# activations are bfloat16 (quant_state.dtype from bnb-4bit checkpoints).
# Cast W to match the activation dtype to avoid mixed-dtype matmul errors.
if W.dtype != dtype:
    W = W.to(dtype)
out = torch_matmul(X, W.t(), out = out)
if W_quant is not None:
```
This fixes dtype mismatches in matmul_lora, but the same root cause (4-bit quant_state.dtype driving dequant/output dtype) still exists in other hot paths in this file. For example, fast_linear_forward dequantizes with use_global_buffer=True and then calls torch_matmul(X, W, out=out) (and fast_gemv selects fp16/bf16 kernels based on quant_state.dtype), so dtype=torch.bfloat16 activations with fp16 quant_state.dtype can still trigger dtype-mismatch errors or incorrect kernel selection when q_len == 1 (especially for bsz > 1). Consider applying the same dtype-alignment strategy in those q_len==1 paths as well (cast the dequantized weight/output and/or ensure X is in the expected dtype before calling the bnb kernels).
```python
# Cast W to match the activation dtype to avoid mixed-dtype matmul errors.
if W.dtype != dtype:
    W = W.to(dtype)
out = torch_matmul(X, W.t(), out = out)
```
Casting the full dequantized weight matrix with W = W.to(dtype) can allocate a fresh copy every call when quant_state.dtype != X.dtype, which may negate much of the benefit of use_global_buffer=True and increase memory bandwidth/peak memory. If this mismatch is expected to be common (e.g., bf16 activations + fp16 bnb checkpoints), consider a longer-term approach where the global dequant buffer is keyed by the compute/activation dtype (or otherwise dequantizes directly into the desired dtype) to avoid repeated full-matrix casts.
Suggested change:

```diff
-# Cast W to match the activation dtype to avoid mixed-dtype matmul errors.
-if W.dtype != dtype:
-    W = W.to(dtype)
-out = torch_matmul(X, W.t(), out = out)
+# Avoid casting the full dequantized weight matrix every call; instead,
+# run the base matmul in W.dtype and cast the result back if needed.
+base_X = X if W.dtype == dtype else X.to(W.dtype)
+base_out = out if (out is None or out.dtype == base_X.dtype) else None
+out = torch_matmul(base_X, W.t(), out = base_out)
+if out.dtype != dtype:
+    out = out.to(dtype)
```
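The suggested change is a diff fragment; a runnable paraphrase follows. `lora_base_matmul` is a hypothetical name, and float32 stands in for the checkpoint's float16 dequant dtype so the sketch runs on CPU; the point is that only the smaller activation/output tensors get copied, never the full weight matrix.

```python
import torch

def lora_base_matmul(X, W, out=None):
    # Sketch of the reviewer's alternative: run the base matmul in W.dtype
    # (avoiding a full copy of the dequantized weight) and cast only the
    # activation and output tensors when dtypes disagree.
    dtype = X.dtype
    base_X = X if W.dtype == dtype else X.to(W.dtype)
    base_out = out if (out is None or out.dtype == base_X.dtype) else None
    result = torch.matmul(base_X, W.t(), out=base_out)
    return result if result.dtype == dtype else result.to(dtype)

X = torch.randn(2, 4, dtype=torch.bfloat16)  # activations
W = torch.randn(8, 4, dtype=torch.float32)   # stand-in dequant buffer dtype
y = lora_base_matmul(X, W)
print(y.dtype)  # torch.bfloat16
```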
Replace hardcoded torch.amp.autocast("cuda", enabled=False) with
torch.amp.autocast(X.device.type, enabled=False) in all three LoRA
backward methods (LoRA_MLP_SwiGLU, LoRA_QKV, LoRA_W).
The file already switches torch_amp_custom_bwd to device_type="xpu" on
Intel XPU (utils.py:53-55). Hardcoding "cuda" in the autocast guard
would target the wrong context on XPU and may error on systems without
CUDA. Deriving the device type from the input tensor makes the fix
backend-agnostic.
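The mechanism behind the fix can be demonstrated on CPU, where autocast uses bfloat16 instead of the float16 that TRL's CUDA context uses; the downcast-and-guard behavior is the same. This is an illustrative sketch, not unsloth code.

```python
import torch

x = torch.randn(2, 2)  # float32 "gradient" tensor
w = torch.randn(2, 2)  # float32 weight

# An ambient autocast context (like the one @torch_amp_custom_bwd re-enters)
# silently downcasts eligible ops such as matmul:
with torch.amp.autocast("cpu"):
    inside = torch.matmul(x, w)
    # The fix: derive the device type from the tensor and disable autocast,
    # so the backward math runs in the tensors' own dtypes.
    with torch.amp.autocast(x.device.type, enabled=False):
        guarded = torch.matmul(x, w)

print(inside.dtype, guarded.dtype)  # torch.bfloat16 torch.float32
```

Using `x.device.type` rather than a literal `"cuda"` keeps the guard correct on CUDA, XPU, and CPU alike.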
Hey @anandn1, while the root cause analysis seems plausible, the fix doesn't seem optimal. It is much easier and cleaner to fix the argument to show bfloat16 instead of typecasting activations in the kernels. This is definitely not the ideal way, as it would cost us performance.
Add W_quant.dtype = dtype assignment before each fast_dequantize call in matmul_lora (forward) and all three LoRA backward methods (MLP, QKV, W). This ensures fast_dequantize selects the correct NF4 kernel (cdequantize_blockwise_bf16_nf4 vs fp16_nf4) and allocates the output buffer in the activation dtype directly, eliminating unnecessary kernel path divergence. Safety post-casts are retained as a fallback. Addresses review feedback from @Datta0 on PR unslothai#4918.
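A minimal sketch of this commit's approach. `FakeQuantState` and `fake_fast_dequantize` are stand-ins for bitsandbytes' quant state and unsloth's dequant helper, not the real APIs; they only model the fact that the dequant kernel and output buffer follow `quant_state.dtype`.

```python
import torch

class FakeQuantState:
    # stand-in for the quant_state carried by a bnb-4bit weight
    def __init__(self, dtype):
        self.dtype = dtype

def fake_fast_dequantize(W_packed, quant_state):
    # stand-in: the real fast_dequantize picks its NF4 kernel and allocates
    # its output buffer according to quant_state.dtype
    return W_packed.to(quant_state.dtype)

def dequantize_in_activation_dtype(W_packed, quant_state, dtype):
    quant_state.dtype = dtype  # steer kernel/buffer selection up front
    W = fake_fast_dequantize(W_packed, quant_state)
    if W.dtype != dtype:       # safety post-cast retained as a fallback
        W = W.to(dtype)
    return W

qs = FakeQuantState(torch.float16)  # fp16 bnb-4bit checkpoint
W = dequantize_in_activation_dtype(torch.randn(4, 4), qs, torch.bfloat16)
print(W.dtype)  # torch.bfloat16
```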
Thanks for the review @Datta0! Updated the approach in the latest commit. Instead of casting after dequantization, we now set `W_quant.dtype = dtype` before each `fast_dequantize` call, so the correct NF4 kernel is selected and the output buffer is allocated in the activation dtype directly. The previous post-cast (`W.to(dtype)`) is retained as a safety fallback. Applied at 4 sites: `matmul_lora` (forward) and the three LoRA backward methods (MLP, QKV, W).
I am not very keen on such a safety net that fails silently, tbh. If an error is not supposed to happen, we should report it rather than fail silently, especially when it hurts performance.
@Datta0 Ok, I'll implement the suggested approach. Feedback appreciated!
Problem
Fixes #4891.
GRPO training with `unsloth/Llama-3.2-3B-Instruct-bnb-4bit` (bf16 activations) crashes with two distinct dtype mismatch errors, even after setting `dtype=torch.bfloat16`, `bf16=True`, and `fp16=False`:

Error 1 — forward pass (`utils.py:matmul_lora`, LoRA branch)

Error 2 — backward pass (`fast_lora.py`, `LoRA_MLP_SwiGLU`/`LoRA_QKV`)

Root Causes
1. `fast_dequantize` returns float16 regardless of load-time `dtype=`

`fast_dequantize` writes into a global `WEIGHT_BUFFERS` cache whose dtype is controlled by `quant_state.dtype` embedded in the bnb-4bit checkpoint (float16 by default). Setting `dtype=torch.bfloat16` at load time does not update `quant_state.dtype`, so the dequantized base weight `W` comes out float16 even when activations are bfloat16. The `out` tensor from the base-weight matmul is therefore float16, and the subsequent `out.addmm_(XA, B.to(dtype))` in the LoRA branch crashes because `out` is float16 but `B.to(dtype)` is bfloat16.

2. `@torch_amp_custom_bwd` inherits TRL's float16 autocast context

TRL's compiled GRPO trainer establishes a float16 `autocast` context for parts of training. `@torch_amp_custom_bwd` re-enters that same context during the custom backward pass, silently downcasting float32 gradient tensors (`dY`, `dQ/dK/dV`) to float16 mid-computation. The subsequent `addmm_` calls see mixed float32/float16 operands and crash.

Fixes
`unsloth/kernels/utils.py` — `matmul_lora`
- After `fast_dequantize`, cast `W` to the activation dtype if they differ.
- Cast `out` and `X` to the activation dtype before `addmm_`, ensuring the base-weight matmul output dtype never bleeds into the LoRA accumulation.

`unsloth/kernels/fast_lora.py` — `LoRA_MLP_SwiGLU.backward` and `LoRA_QKV.backward`
- Wrap each backward body in `torch.amp.autocast("cuda", enabled=False)` to prevent the inherited float16 context from downcasting gradients.
- Explicitly cast incoming gradient tensors (`dY`, `dQ/dK/dV`) and all dequantized base weights (`upW`, `gateW`, `QW`, `KW`, `VW`) to `X.dtype`.

Relation to PR #4005
PR #4005 threads `correct_dtype` through `patch_model_and_tokenizer`. This PR addresses the remaining kernel-level gaps: the global dequant buffer dtype mismatch in the `matmul_lora` forward, and the autocast context inheritance in the LoRA custom backward functions.

Reproducer
Environment: NVIDIA RTX 5050 Laptop GPU, CUDA 12.8, unsloth 2026.4.4, TRL 0.22.x, PyTorch 2.10