
Enable kernels-community/metal-flash-sdpa on MPS #45974

Open

ArthurZucker wants to merge 4 commits into main from enable-metal-flash-sdpa-mps


Conversation

@ArthurZucker
Collaborator

Enables kernels-community/metal-flash-sdpa for generate/generate_batch on MPS: synthesizes cu_seqlens and forces .contiguous() in the no-padding branch of _flash_attention_forward, and fixes MPS memory accounting in continuous batching.

Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch):

  impl                                       time(s)    tok/s    acc
  sdpa                                        149.33    158.4    30/100
  kernels-community/metal-flash-sdpa           89.78    256.0    32/100

1.66× speedup, accuracy within noise.

Follow-up: push the contiguity/varlen handling into the kernel itself so modeling_flash_attention_utils.py no longer needs the fallback branch.

Two small fixes so `attn_implementation="kernels-community/metal-flash-sdpa"`
works end-to-end on Apple Silicon (`generate` and `generate_batch`):

* `modeling_flash_attention_utils._flash_attention_forward`: the "no padding"
  branch unconditionally called `flash_fn`, which is `None` for varlen-only
  kernels (the metal kernel only ships `flash_attn_varlen_func`). Synthesize
  `cu_seqlens` for the dense batched layout and route through `flash_varlen_fn`
  in that case. `.contiguous()` before reshape is required: the cached
  K/V (post-transpose) is non-contiguous and the Metal kernel reads garbage
  off it during decode, producing nonsense tokens. (See the first sketch after
  this list.)

* `continuous_batching/requests.get_device_and_memory_breakdown`: on MPS,
  `torch.mps.driver_allocated_memory()` returns bytes currently held by the
  Metal driver (≈0 right after process start), not the total. Use
  `recommended_max_memory()` for total and `current_allocated_memory()` for
  the running allocation. Without this, `infer_num_blocks_and_max_batch_tokens`
  either returns a negative `num_blocks` or refuses to allocate, so
  `generate_batch` is unusable on MPS regardless of the chosen attention. (See the
  second sketch after this list.)
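
A minimal sketch of the varlen synthesis from the first bullet, assuming a dense
`(batch, seq_len, num_heads, head_dim)` layout with equal-length sequences. This is
not the PR's exact code, and a later commit in this PR moves the handling out of the
modeling forward again:

```python
import torch


def _cu_seqlens(batch_size: int, seq_len: int, device) -> torch.Tensor:
    # Equal-length sequences, so the cumulative lengths are just multiples of seq_len.
    return torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=device)


def dense_to_varlen(q, k, v):
    """Flatten a dense (batch, seq, heads, dim) batch into the layout a
    flash_attn_varlen_func-style kernel expects."""
    # The cached K/V comes out of a transpose and is non-contiguous; the Metal kernel
    # reads garbage from non-contiguous inputs during decode, so force contiguity first.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    bsz, q_len, k_len = q.shape[0], q.shape[1], k.shape[1]
    cu_seqlens_q = _cu_seqlens(bsz, q_len, q.device)
    cu_seqlens_k = _cu_seqlens(bsz, k_len, k.device)
    # Collapse batch and sequence into a single "total tokens" dimension.
    q = q.reshape(bsz * q_len, *q.shape[2:])
    k = k.reshape(bsz * k_len, *k.shape[2:])
    v = v.reshape(bsz * k_len, *v.shape[2:])
    return q, k, v, cu_seqlens_q, cu_seqlens_k, q_len, k_len
```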

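A sketch of the MPS memory accounting from the second bullet, assuming the three
`torch.mps` introspection functions below exist in the installed PyTorch:

```python
import torch


def mps_memory_breakdown():
    """Return (total, reserved, allocated) bytes for the MPS device."""
    # recommended_max_memory() is the practical upper bound the process should use;
    # driver_allocated_memory() only reports what the Metal driver currently holds
    # (roughly 0 right after process start), so it must not be treated as the total.
    total = torch.mps.recommended_max_memory()
    reserved = torch.mps.driver_allocated_memory()
    allocated = torch.mps.current_allocated_memory()
    return total, reserved, allocated
```
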
Bench (gsm8k 100 samples, Qwen2.5-0.5B-Instruct, MPS fp16, generate_batch):

  impl                                       time(s)    tok/s    acc
  sdpa                                        149.33    158.4    30/100
  kernels-community/metal-flash-sdpa           89.78    256.0    32/100

1.66x speedup, accuracy within noise.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker requested a review from remi-or May 14, 2026 08:10
When `transformers serve` runs on Apple Silicon (`--device auto` or `mps`)
with `kernels` installed and no explicit `--attn-implementation` flag, default
the attention to `kernels-community/metal-flash-sdpa` instead of plain SDPA.

On the 100-sample gsm8k benchmark (Qwen2.5-0.5B-Instruct, MPS fp16,
generate_batch) it's a 1.66x throughput improvement (158 -> 256 tok/s) with
token-for-token parity for greedy decoding. Users who don't want it can opt
out with `--attn-implementation sdpa`.

Help text on the `--attn-implementation` flag also now lists the kernels-hub
syntax explicitly.
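
A rough sketch of the selection logic described above; the helper name and exact
argument handling here are assumptions for illustration, not the PR's code:

```python
import importlib.util


def pick_default_attn_implementation(attn_implementation: str | None, device: str) -> str | None:
    # An explicit --attn-implementation flag always wins.
    if attn_implementation is not None:
        return attn_implementation
    # On Apple Silicon with the `kernels` package installed, default to the Metal flash kernel.
    if str(device).startswith("mps") and importlib.util.find_spec("kernels") is not None:
        return "kernels-community/metal-flash-sdpa"
    # Otherwise leave it unset so the model falls back to its usual default (sdpa).
    return None
```
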
Don't build cu_seqlens on the fly inside the modeling forward — the
non-padding `else` branch can stay as a NoneType failure for varlen-only
kernels. Callers that need varlen (continuous batching, padding-free
training) go through `paged_attention_forward` or the explicit
`cu_seq_lens_*` kwarg path, both of which already supply their own
cumulative lengths.

Companion kernel change: dropping `flash_attn_func` from
kernels-community/metal-flash-sdpa for the same reason (PR #3).
The published `main` of `kernels-community/metal-flash-sdpa` predates the MPS
dispatch hardening (contiguity, int32 cast, alias clone, MPS encoder flush)
that this integration depends on. Pinning to the open PR's HEAD commit so the
auto-default actually works end-to-end out of the box.

Drop / bump this constant when the upstream PR merges:
  https://huggingface.co/kernels-community/metal-flash-sdpa/discussions/3
Collaborator

@remi-or remi-or left a comment


Please apply the line-length limit and fix the bug!

@@ -800,8 +800,36 @@ def _flash_attention_forward(

# No padding
else:
Collaborator


why not `elif`?

def _resolve_attn_implementation(cls, attn_implementation: str | None, device: str | int) -> str | None:
"""Auto-select a flash-attention kernel when the user didn't specify one.

On Apple Silicon (MPS) with ``kernels`` installed, default to
Collaborator


I think the line length is 120, you should update your CLAUDE.md!

# MPS memory reporting (PyTorch 2.0+). `driver_allocated_memory` returns bytes currently held by
# the Metal driver (≈ 0 right after process start), so use `recommended_max_memory` for total
# and `current_allocated_memory` for the running allocation instead.
total_memory = getattr(torch.mps, "recommended_max_memory")()
Collaborator


I think you forgot the default here
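
Presumably the fix being asked for looks something like the following, falling back
when an older PyTorch build lacks `recommended_max_memory` (the choice of fallback is
an assumption, not part of the review):

```python
# Fall back to driver_allocated_memory() on PyTorch builds without recommended_max_memory().
total_memory = getattr(torch.mps, "recommended_max_memory", torch.mps.driver_allocated_memory)()
```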
