[MAX] Optimize QwenImage DiT RoPE: graph ops → fused GPU kernel #6345
Open
byungchul-sqzb wants to merge 13 commits intomodular:mainfrom
Open
[MAX] Optimize QwenImage DiT RoPE: graph ops → fused GPU kernel #6345byungchul-sqzb wants to merge 13 commits intomodular:mainfrom
byungchul-sqzb wants to merge 13 commits intomodular:mainfrom
Conversation
# Conflicts: # max/python/max/pipelines/architectures/qwen_image/layers/qwen_image_attention.py # max/python/max/pipelines/architectures/qwen_image/qwen_image.py
# Conflicts: # max/examples/diffusion/simple_offline_generation.py # max/python/max/pipelines/core/context.py
Replace graph-level complex multiplication (~10 ops per call) with the fused `rope_ragged_with_position_ids` GPU kernel for RoPE application in QwenImage transformer blocks. This mirrors the approach already used by flux2 and z-image. With 60 dual-stream blocks applying RoPE to both Q and K (120 calls per step), this eliminates significant kernel launch overhead and intermediate memory allocations. Measured improvement: transformer step avg 100.6ms → 81.9ms (-18.5%). Changes: - QwenImagePosEmbed: return interleaved freqs_cis [S, D] instead of separate (cos, sin) tuple - QwenImageAttention: use rope_ragged_with_position_ids kernel - qwen_image.py: cast freqs_cis to model dtype (z-image pattern) - Update parity test fixtures for new single-tensor interface
56a7ace to
455c476
Compare
455c476 to
e2d9c8a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Important
The core changes of this optimization are contained entirely within commit: e72d459.
ONLY 7 files changed (+82 -50 lines changed)
Summary
Optimize QwenImage DiT RoPE by replacing graph-level complex multiplication with the fused
rope_ragged_with_position_idsGPU kernel.The existing
apply_rotary_embimplementation performs ~10 graph ops per call (reshape, split, cast, 4x elementwise multiply, stack, reshape, cast back). This PR replaces it with a single fused Mojo GPU kernel (mo.rope.ragged.with_position_id), following the same pattern already used by flux2 and z-image.With 60 dual-stream blocks applying RoPE to both Q and K (120 kernel invocations per denoising step), this eliminates significant kernel launch overhead and intermediate memory allocations.
Measured improvement:
qwen_image(text-to-image): transformer step avg 100.6ms → 81.9msqwen_image_edit(image editing): transformer step avg 238.4ms → 178.0msChanges:
QwenImagePosEmbed: return interleavedfreqs_cistensor[S, D]instead of separate(cos, sin)tupleQwenImageAttention: userope_ragged_with_position_idskernel via new_apply_qwen_image_qk_ropehelperqwen_image.py: castfreqs_cisto model dtype (matching z-image pattern)Both
qwen_imageandqwen_image_editbenefit from this change sinceqwen_image_editreusesQwenImageTransformer2DModeldirectly.Note on Dependencies
This PR is developed and tested on a local environment that assumes all initial Qwen-Image implementation PRs (#6139 #6140 and #6141 ) have been merged. Since those are still pending, I plan to rebase this PR once the base implementations are landed in main.
Testing
./bazelw run //:format— all checks passedqwen_imageand 80-step onqwen_image_edit, confirming 18.5% transformer speedup on both pipelines with no change in output quality./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "Qwen/Qwen-Image-2512" --prompt "dog dancing near the sun" --negative-prompt " " --num-inference-steps 50 --guidance-scale 4.0 --true-cfg-scale 4.0 --seed 42 --profile-timings --num-profile-iterations 1 --num-warmups 1-Qwen-Image-Edit
./bazelw run //max/examples/diffusion:simple_offline_generation -- --model "Qwen/Qwen-Image-Edit-2511" --prompt "use soft lighting to relight the image." --negative-prompt " " --num-inference-steps 40 --guidance-scale 4.0 --true-cfg-scale 4.0 --seed 42 --profile-timings --num-profile-iterations 1 --num-warmups 1 --input-image "../input.jpeg"Checklist
sequence of smaller PRs
./bazelw run formatto format my changesAssisted-by:trailer in my commit message or this PR description(see AI Tool Use Policy)
Assisted-by: AI