
[MAX] Optimize QwenImage DiT RoPE: graph ops → fused GPU kernel #6345

Open
byungchul-sqzb wants to merge 13 commits into modular:main from SqueezeBits:opt/qwen-image-rope-kernel

Conversation

@byungchul-sqzb
Contributor

@byungchul-sqzb byungchul-sqzb commented Apr 3, 2026

Important

The core changes of this optimization are contained entirely within commit: e72d459.
Only 7 files changed (+82 / −50 lines).

  • Please review the diff of e72d459 only (vs its parent). Merge commits / merges from modular:main may add many unrelated files — please ignore those for this review.
  • Direct link (if helpful): open commit e72d459 in GitHub and inspect “changed files” there

Summary

Optimize QwenImage DiT RoPE by replacing graph-level complex multiplication with the fused rope_ragged_with_position_ids GPU kernel.

The existing apply_rotary_emb implementation performs ~10 graph ops per call (reshape, split, cast, 4x elementwise multiply, stack, reshape, cast back). This PR replaces it with a single fused Mojo GPU kernel (mo.rope.ragged.with_position_id), following the same pattern already used by flux2 and z-image.

With 60 dual-stream blocks applying RoPE to both Q and K (120 kernel invocations per denoising step), this eliminates significant kernel launch overhead and intermediate memory allocations.
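The op sequence being replaced can be sketched in NumPy. This is a hypothetical reconstruction for illustration only, not the actual MAX graph code; the function name and exact op ordering are assumptions based on the description above:

```python
import numpy as np

def apply_rotary_emb_graph_ops(x, cos, sin):
    """Illustrative sketch of the pre-optimization graph-level RoPE:
    cast, reshape, split into real/imag halves, four elementwise
    multiplies, stack, reshape, and cast back -- roughly ten graph ops,
    each a separate kernel launch plus intermediate buffer on GPU.
    (Hypothetical reconstruction, not the actual MAX implementation.)"""
    orig_dtype = x.dtype
    xf = x.astype(np.float32)                       # cast up
    xr = xf.reshape(*xf.shape[:-1], -1, 2)          # reshape into pairs
    x_re, x_im = xr[..., 0], xr[..., 1]             # split real/imag
    out_re = x_re * cos - x_im * sin                # 2 multiplies + sub
    out_im = x_re * sin + x_im * cos                # 2 multiplies + add
    out = np.stack([out_re, out_im], axis=-1)       # stack pairs back
    return out.reshape(x.shape).astype(orig_dtype)  # reshape, cast back

# Sanity check: rotating by angle 0 (cos=1, sin=0) is the identity.
x = np.arange(8, dtype=np.float16).reshape(1, 8)
cos = np.ones((1, 4), dtype=np.float32)
sin = np.zeros((1, 4), dtype=np.float32)
assert np.allclose(apply_rotary_emb_graph_ops(x, cos, sin), x)
```

The fused kernel performs this whole sequence in a single launch, which is where the per-step savings come from.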

Measured improvement:

  • qwen_image (text-to-image): transformer step avg 100.6ms → 81.9ms
  • qwen_image_edit (image editing): transformer step avg 238.4ms → 178.0ms

Changes:

  • QwenImagePosEmbed: return interleaved freqs_cis tensor [S, D] instead of separate (cos, sin) tuple
  • QwenImageAttention: use rope_ragged_with_position_ids kernel via new _apply_qwen_image_qk_rope helper
  • qwen_image.py: cast freqs_cis to model dtype (matching z-image pattern)
  • Update parity test fixtures for new single-tensor RoPE interface

Both qwen_image and qwen_image_edit benefit from this change since qwen_image_edit reuses QwenImageTransformer2DModel directly.
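The `QwenImagePosEmbed` change packs the former `(cos, sin)` tuple into one `[S, D]` tensor. A minimal sketch of that packing, assuming an interleaved `[cos0, sin0, cos1, sin1, …]` layout along the last axis (the actual layout is whatever `rope_ragged_with_position_ids` expects; the function name here is hypothetical):

```python
import numpy as np

def interleave_freqs_cis(cos, sin):
    """Pack separate cos/sin tables of shape [S, D/2] into a single
    interleaved freqs_cis tensor of shape [S, D]:
    [cos0, sin0, cos1, sin1, ...] along the last axis.
    (Layout is an assumption for illustration.)"""
    seq_len, half_dim = cos.shape
    out = np.empty((seq_len, 2 * half_dim), dtype=cos.dtype)
    out[:, 0::2] = cos  # even slots hold cosines
    out[:, 1::2] = sin  # odd slots hold sines
    return out

cos = np.array([[1.0, 0.5]])
sin = np.array([[0.0, 0.25]])
assert interleave_freqs_cis(cos, sin).tolist() == [[1.0, 0.0, 0.5, 0.25]]
```

Handing the kernel one contiguous tensor avoids carrying two separate graph values (and two dtype casts) through every attention block.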

Note on Dependencies

This PR was developed and tested in a local environment that assumes all initial Qwen-Image implementation PRs (#6139, #6140, and #6141) have been merged. Since those are still pending, I plan to rebase this PR once the base implementations land in main.

Testing

  • Ran ./bazelw run //:format — all checks passed
  • Profiled end-to-end denoising (100 transformer calls on qwen_image, 80 on qwen_image_edit), confirming a ~18.5% transformer speedup on qwen_image and ~25% on qwen_image_edit, with no change in output quality
  • Qwen-Image
    ./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "Qwen/Qwen-Image-2512" --prompt "dog dancing near the sun" --negative-prompt " " --num-inference-steps 50 --guidance-scale 4.0 --true-cfg-scale 4.0 --seed 42 --profile-timings --num-profile-iterations 1 --num-warmups 1
...
Warmup complete
Running diffusion model...
Running inference 1 of 1
==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer              100     8318.686       83.187
component/vae.decode                 1       42.703       42.703
component/text_encoder               2       10.183        5.092

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          1     8385.103     8385.103
component/transformer              100     8318.686       83.187
decode_latents                       1       48.924       48.924
component/vae.decode                 1       42.703       42.703
prepare_embeddings                   2       10.407        5.203
component/text_encoder               2       10.183        5.092
preprocess_latents                   1        0.270        0.270
==========================================

  • Qwen-Image-Edit
./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "Qwen/Qwen-Image-Edit-2511" --prompt "use soft lighting to relight the image." --negative-prompt " " --num-inference-steps 40 --guidance-scale 4.0 --true-cfg-scale 4.0 --seed 42 --profile-timings --num-profile-iterations 1 --num-warmups 1 --input-image "../input.jpeg"

...
Warmup complete
Running diffusion model...
Running inference 1 of 1
==================== PROFILING REPORT ==
Component Timings:
components                       calls        total          avg (ms)
component/transformer               80    14243.197      178.040
component/vae.decode                 1       46.668       46.668
component/vae.encode                 1       27.903       27.903

Method Timings:
methods                          calls        total          avg (ms)
E2E execute                          1    14373.698    14373.698
component/transformer               80    14243.197      178.040
decode_latents                       1       52.391       52.391
component/vae.decode                 1       46.668       46.668
prepare_image_latents                1       28.235       28.235
component/vae.encode                 1       27.903       27.903
preprocess_latents                   1        0.424        0.424
==========================================

Checklist

  • PR is small and focused — consider splitting larger changes into a
    sequence of smaller PRs
  • I ran ./bazelw run format to format my changes
  • I added or updated tests to cover my changes
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description
    (see AI Tool Use Policy)

Assisted-by: AI

jglee-sqbits and others added 12 commits March 24, 2026 08:37
# Conflicts:
#	max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py
#	max/python/max/pipelines/architectures/qwen_image/qwen_image.py
# Conflicts:
#	max/examples/diffusion/simple_offline_generation.py
#	max/python/max/pipelines/core/context.py
Replace graph-level complex multiplication (~10 ops per call) with the
fused `rope_ragged_with_position_ids` GPU kernel for RoPE application
in QwenImage transformer blocks. This mirrors the approach already used
by flux2 and z-image.

With 60 dual-stream blocks applying RoPE to both Q and K (120 calls per
step), this eliminates significant kernel launch overhead and intermediate
memory allocations. Measured improvement: transformer step avg 100.6ms →
81.9ms (-18.5%).

Changes:
- QwenImagePosEmbed: return interleaved freqs_cis [S, D] instead of
  separate (cos, sin) tuple
- QwenImageAttention: use rope_ragged_with_position_ids kernel
- qwen_image.py: cast freqs_cis to model dtype (z-image pattern)
- Update parity test fixtures for new single-tensor interface
@byungchul-sqzb byungchul-sqzb requested a review from a team as a code owner April 3, 2026 05:34
@byungchul-sqzb byungchul-sqzb force-pushed the opt/qwen-image-rope-kernel branch 2 times, most recently from 56a7ace to 455c476 Compare April 3, 2026 08:57
@byungchul-sqzb byungchul-sqzb force-pushed the opt/qwen-image-rope-kernel branch from 455c476 to e2d9c8a Compare April 3, 2026 08:59
