
expose num_splits in FA2 + bump flash attention submodule commit #179760

Closed

liangel-02 wants to merge 9 commits into gh/liangel-02/42/base from gh/liangel-02/42/head

Conversation

@liangel-02
Contributor

liangel-02 commented Apr 8, 2026

Enable the num_splits parameter for FA2; when num_splits=1, align the paged-KV block size to match the standard kernel. For testing, torch.equal() is used to check numerics in the existing paging test.

Since this involves changes in the upstream flash-attention repo, this also bumps the submodule commit.
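The block-size alignment described above can be sketched in pure Python. This is a minimal illustration: the function name and the 128-wide tile size are assumptions for this sketch, not the actual flash_api.cpp code.

```python
# Sketch of the alignment rule: when num_splits == 1, the paged-KV path
# uses the same KV tile size as the standard kernel, so both paths reduce
# in the same order and produce bit-identical outputs (which is why
# torch.equal can be used in the paging test instead of an allclose
# tolerance).

STANDARD_KERNEL_BLOCK_N = 128  # assumed standard-kernel KV tile size


def choose_paged_kv_block_size(num_splits: int, requested_block_size: int) -> int:
    """Pick the KV block size for the paged path (illustrative only)."""
    if num_splits == 1:
        # Align with the standard kernel for bit-identical numerics.
        return STANDARD_KERNEL_BLOCK_N
    return requested_block_size
```

With num_splits=1 the paged path is forced onto the standard tile size; any other value leaves the caller's requested block size untouched.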

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]

pytorch-bot Bot commented Apr 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179760

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (1 Unrelated Failure)

As of commit aca36e4 with merge base 420f50f:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

liangel-02 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: a6a44d7
Pull Request resolved: #179760

pytorch-bot Bot commented Apr 8, 2026

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

liangel-02 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: ea1e3de
Pull Request resolved: #179760
liangel-02 added a commit that referenced this pull request Apr 8, 2026
ghstack-source-id: 5a5621a
Pull Request resolved: #179760
@liangel-02 liangel-02 requested a review from drisspg April 9, 2026 16:17
@liangel-02 liangel-02 added the release notes: nn release notes category label Apr 9, 2026
[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 16, 2026
ghstack-source-id: 218dbd2
Pull Request resolved: #179760
@liangel-02 liangel-02 added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 16, 2026
Contributor

drisspg commented Apr 16, 2026

@claude Please do a deep review focused on subtle correctness bugs and regressions that are easy for human reviewers to miss. Prioritize findings over summary. I don't care about anything that isn't a real bug; drop all the noise. If there isn't anything to report, just say so with a small comment.


claude Bot commented Apr 16, 2026

Claude finished @drisspg's task in 3m 34s.


  • Read full diff and changed files
  • Analyze correctness of num_splits plumbing
  • Analyze block size alignment logic
  • Check test correctness
  • Report findings

Findings

I didn't find any subtle correctness bugs. The num_splits plumbing is straightforward and the default value_or(0) preserves the old behavior. The test using torch.equal for bit-identical checks under num_splits=1 is correct.

One behavioral note (not a bug): the condition change from if (seqlenq_ngroups_swapped) to if (paged_KV || seqlenq_ngroups_swapped) at flash_api.cpp:767 means set_params_splitkv is now called for paged KV prefill (multi-token seqlen_q > 1) even with the default num_splits=0. The heuristic in set_params_splitkv can select num_splits > 1, which changes execution for existing paged KV workloads that don't pass num_splits. Before this PR, those workloads had params.num_splits=0 when hitting the splitkv kernel. This seems intentional based on the PR goals but worth confirming it doesn't regress any existing paged KV benchmarks.
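The behavioral note above can be modeled with a small pure-Python sketch. The function names mirror those mentioned in the review, but the heuristic body is a toy stand-in, not the real occupancy-based logic in flash_api.cpp.

```python
# Illustrative model of the condition change at flash_api.cpp:767
# (the split-count heuristic below is a toy assumption for this sketch).

def set_params_splitkv(num_splits: int, seqlen_k: int) -> int:
    """Honor a caller-provided num_splits > 0; otherwise pick a split
    count heuristically from the KV length (toy stand-in heuristic)."""
    if num_splits > 0:
        return num_splits
    # Toy heuristic: more splits for longer KV sequences, capped at 8.
    return min(max(seqlen_k // 1024, 1), 8)


def effective_num_splits(paged_kv: bool, swapped: bool,
                         num_splits: int, seqlen_k: int) -> int:
    # Before the PR only the seqlenq_ngroups_swapped path hit
    # set_params_splitkv, so paged-KV prefill kept num_splits == 0.
    # After the PR the paged-KV path also routes through the heuristic.
    if paged_kv or swapped:
        return set_params_splitkv(num_splits, seqlen_k)
    return num_splits
```

Under this model, a paged-KV prefill that passes the default num_splits=0 can now end up with num_splits > 1, which is exactly the execution change flagged for benchmarking above.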

[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 20, 2026
ghstack-source-id: 1581393
Pull Request resolved: #179760
[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 20, 2026
ghstack-source-id: 2486c0c
Pull Request resolved: #179760
@liangel-02
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

This PR updates submodules third_party/flash-attention

If those updates are intentional, please add "submodule" keyword to PR title/description.

@liangel-02 liangel-02 changed the title from "expose num_splits in FA2" to "expose num_splits in FA2 + bump flash attention submodule commit" Apr 21, 2026
@liangel-02
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 3 checks: trunk / win-vs2022-cuda13.0-py3 / build, trunk / win-vs2022-cpu-py3 / test (default, 1, 4, windows.4xlarge.nonephemeral), trunk / win-vs2022-cpu-py3 / test (default, 4, 4, windows.4xlarge.nonephemeral)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Apr 21, 2026
Revert "expose num_splits in FA2 + bump flash attention submodule commit (#179760)"

This reverts commit b2e1f52.

Reverted #179760 on behalf of https://github.com/jeffdaily due to broke ROCm CI test/test_varlen_attention.py::TestVarlenAttentionCUDA::test_batch_invariance_float16_num_splits_1_window_size3_backend_fa2_cuda_float16 [GH job link](https://github.com/pytorch/pytorch/actions/runs/24696802211/job/72258230939) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/b2e1f526d2c562e3a3406aff63a61cb0370155fb) ([comment](#179760 (comment)))
@pytorchmergebot
Collaborator

@liangel-02 your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Apr 21, 2026
[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 21, 2026
ghstack-source-id: 79d9db1
Pull Request resolved: #179760
[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 21, 2026
ghstack-source-id: 66faf2c
Pull Request resolved: #179760
@liangel-02 liangel-02 added the labels ciflow/rocm, ciflow/rocm-mi300, ciflow/rocm-mi355, ciflow/rocm-mi200 (Trigger "default" config CI on ROCm / MI300 / MI355 / MI200 runners) and removed the labels Merged, Reverted, ciflow/rocm, ciflow/rocm-mi300, ciflow/rocm-mi355, ciflow/rocm-mi200 Apr 22, 2026
[ghstack-poisoned]
liangel-02 added a commit that referenced this pull request Apr 22, 2026
ghstack-source-id: 13d1e79
Pull Request resolved: #179760
@pytorch pytorch deleted a comment from pytorch-bot Bot Apr 22, 2026
@liangel-02
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / linux-jammy-rocm-py3.10-mi355 / test (default, 5, 6, linux.rocm.gpu.gfx950.1)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: nn (release notes category)


4 participants