Skip to content

Refactor QKV Fusion Utilities to be LoRA-Aware#14047

Open
dg845 wants to merge 7 commits into
mainfrom
refactor/lora-aware-qkv-fusion
Open

Refactor QKV Fusion Utilities to be LoRA-Aware#14047
dg845 wants to merge 7 commits into
mainfrom
refactor/lora-aware-qkv-fusion

Conversation

@dg845

@dg845 dg845 commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

What does this PR do?

This PR refactors the attention QKV fusion utilities in AttentionMixin and AttentionModuleMixin to be more LoRA-aware. In particular, this PR adds guards when attempting to fuse/unfuse with a LoRA attached (because LoRAs cannot be easily transferred over when fusing/unfusing) and an inplace option to fuse without keeping copies of the split Q,K,V projections.

Changelist

  1. Adds guards when attempting to fuse/unfuse Q,K,V with a LoRA attached (will raise an error).
  2. Adds an inplace argument; if inplace=True, the module is modified to have only the fused QKV projection (e.g. to_qkv) with the split Q,K,V projections (e.g. to_q/to_k/to_v) removed. (inplace=False, the default, retains the current behavior).
  3. Supports fusion in the case where add_k_proj and add_v_proj are present without add_q_proj also present (e.g. Wan models).
  4. Adds get_qkv and get_added_qkv helper methods in AttentionModuleMixin which handles getting the Q, K, V (and added Q,K,V, for second stream projections in MM-DiT-style models like Flux) in both the fused and split case. This is intended to make it easier for attention processors to support both fused and split QKV.
  5. Adds an experimental restore_checkpoint_fusion_state method to AttentionMixin to put models back in the fusion state of the original model checkpoint. A new _native_fused_projections attribute on AttentionModuleMixin is added to allow this state to be described. (The motivation is to make it easier to support PEFT adapters which target the original checkpoint structure.)

Partially addresses #14003.

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul
@DN6

dg845 and others added 5 commits June 21, 2026 07:46
…e to original fusion state, module-level helpers to get Q,K,V in both fused and split cases
…duleMixin

Tests are in tests/models/test_attention_mixins.py and cover the four minimal
concrete AttentionModuleMixin fixtures (_MinimalSelfAttn, _MinimalCrossAttn,
_MinimalAddedKVAttn, _MinimalAddedQKVAttn):

TestAttentionModuleMixin (53 tests):
- Idempotency of fuse_projections/unfuse_projections
- Module attribute invariants for non-inplace and inplace paths
- Weight/bias correctness: fused weight equals concatenation of split weights
- Inplace round-trip weight preservation and storage-sharing (no copy on unfuse)
- Cross-attention to_kv path, added-KV to_added_kv path (Wan-style), and
  added-QKV to_added_qkv path (Flux-style)
- get_qkv and get_added_qkv numerical correctness in split and fused cases
- LoRA guard: fuse_projections/unfuse_projections raise ValueError when PEFT-style
  lora_A/lora_B submodules are detected on split or fused projections

TestAttentionMixin (6 tests):
- fuse_qkv_projections/unfuse_qkv_projections propagate to all eligible blocks
- restore_checkpoint_fusion_state respects _native_fused_projections=None/True/False
  per block, including mixed-state models

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@github-actions github-actions Bot added models tests size/L PR with diff > 200 LOC labels Jun 22, 2026
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

models size/L PR with diff > 200 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants