fix(gemma4): cast RoPE offset to int before mx.arange() #4901
danielhanchen merged 2 commits into unslothai:fix/ui-fix from
Conversation
Code Review
This pull request modifies the position index generation in the Gemma4 text model by casting the offset to an integer. A review comment identifies that using int(offset) in MLX is inefficient because it triggers a CPU-GPU synchronization point and breaks compatibility with mx.compile. A suggestion was provided to use a zero-based range added to the offset to maintain performance and compilation support.
```diff
  # x shape: (B, n_heads, L, head_dim)
  seq_len = x.shape[-2]
- positions = mx.arange(offset, offset + seq_len, dtype = mx.float32)
+ positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
```
Using int(offset) is discouraged in MLX because it forces a synchronization point between the GPU and CPU to retrieve the value, which can significantly degrade performance during inference. Furthermore, if this code is executed within an mx.compile block, int(offset) will fail if offset is a tracer array.
A more efficient and compilation-friendly approach is to generate a zero-based range and then add the offset. This avoids the TypeError with mx.arange while supporting both integer and array-based offsets without performance penalties.
```diff
- positions = mx.arange(int(offset), int(offset) + seq_len, dtype = mx.float32)
+ positions = mx.arange(seq_len, dtype = mx.float32) + offset
```
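The shape of the suggested fix can be illustrated outside MLX too. Below is a minimal sketch using NumPy as a stand-in for `mlx.core` (the helper name `rope_positions` is hypothetical, not from the PR); the point is that a zero-based range plus an offset works whether `offset` is a Python int or an array, with no host-side `int()` cast:

```python
import numpy as np  # NumPy stand-in; the same pattern applies to mlx.core

def rope_positions(seq_len, offset):
    # Zero-based range shifted by the offset. Because we never call
    # int(offset), `offset` can remain a device/tracer array in MLX,
    # avoiding a CPU-GPU sync and staying mx.compile-friendly.
    return np.arange(seq_len, dtype=np.float32) + offset

# Works with a plain int offset...
print(rope_positions(4, 3))            # returns [3. 4. 5. 6.]
# ...and with a 0-d array offset, which int()-based code would have to sync on.
print(rope_positions(4, np.array(3)))  # returns [3. 4. 5. 6.]
```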
Good catch — updated the fix to use mx.arange(seq_len) + offset to avoid the CPU-GPU sync point and maintain mx.compile compatibility.

@danielhanchen this fixes a TypeError crashing Gemma 4 inference for all users on the current fix/ui-fix branch — would appreciate a quick review 🙏

Thank you @eauchs - sorry on the delay!
Commits:
* unsloth gemma4 support files
* some fixes
* Fixing cache.empty() calls (#4813)
* Fix/gemma4 mlx (#4816)
* removed bidirectional check for 31b (#4839)
* Add Gemma 4 26B MoE support (MLX) (#4844)
* fix(gemma4): cast RoPE offset to int before mx.arange() (#4901)
* fix(gemma4): use zero-based arange + offset to avoid CPU-GPU sync
* qwen3.6 patches for multi-turn chat
* qwen3.6 script
* removing unnecessary scripts
* displaying errors for not installed packages

Co-authored-by: Roland Tannous, Manan Shah, pre-commit-ci[bot], Manan17, Théophile Lafargue
Problem
mx.arange() receives an mlx.core.array for offset instead of a Python native int, causing a TypeError at inference time with Gemma 4 models.
Fix
Cast offset to int before passing to mx.arange().

Tested on M3 Max 128GB — unsloth/gemma-4-31b-it-UD-MLX-4bit
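For readers without Apple hardware, the merged code path can be sketched roughly as follows (NumPy again standing in for `mlx.core`; `rope_position_ids` is a hypothetical helper name, and the tensor shapes mirror the comment in the diff above):

```python
import numpy as np  # stand-in for mlx.core in this sketch

def rope_position_ids(x, offset):
    # x shape: (B, n_heads, L, head_dim)
    seq_len = x.shape[-2]
    # Final merged form: zero-based arange plus offset, so `offset` may stay
    # a device array (no int() cast, no CPU-GPU sync, compile-friendly).
    return np.arange(seq_len, dtype=np.float32) + offset

x = np.zeros((1, 8, 5, 64), dtype=np.float32)
print(rope_position_ids(x, 10))  # returns [10. 11. 12. 13. 14.]
```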