expose num_splits in FA2 + bump flash attention submodule commit #179760
liangel-02 wants to merge 9 commits into gh/liangel-02/42/base from …
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179760
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
✅ You can merge normally! (1 Unrelated Failure)
As of commit aca36e4 with merge base 420f50f:
BROKEN TRUNK: the following job failed but was already present on the merge base.
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a …
Enable the num_splits param for FA2; if num_splits=1, align the paged-KV block size to match the standard kernel. For testing, I used torch.equal() to check numerics in the existing paging test. [ghstack-poisoned]
@claude Please do a deep review focused on subtle correctness bugs and regressions that are easy for human reviewers to miss. Prioritize findings over summary. I don't care about anything that isn't a real bug, so drop all the noise. If there isn't anything to report, just say so with a small comment.
Claude finished @drisspg's task in 3m 34s
Findings
I didn't find any subtle correctness bugs. One behavioral note (not a bug): the condition change from …
@pytorchbot merge -i
This PR updates the submodule third_party/flash-attention. If those updates are intentional, please add the "submodule" keyword to the PR title/description.
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 3 checks: trunk / win-vs2022-cuda13.0-py3 / build, trunk / win-vs2022-cpu-py3 / test (default, 1, 4, windows.4xlarge.nonephemeral), trunk / win-vs2022-cpu-py3 / test (default, 4, 4, windows.4xlarge.nonephemeral)
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot successfully started a revert job. Check the current status here.
…mit (#179760)"
This reverts commit b2e1f52. Reverted #179760 on behalf of https://github.com/jeffdaily because it broke the ROCm CI test test/test_varlen_attention.py::TestVarlenAttentionCUDA::test_batch_invariance_float16_num_splits_1_window_size3_backend_fa2_cuda_float16 [GH job link](https://github.com/pytorch/pytorch/actions/runs/24696802211/job/72258230939) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/b2e1f526d2c562e3a3406aff63a61cb0370155fb) ([comment](#179760 (comment)))
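For readers unfamiliar with the property the failing test checks: a batch-invariance test asserts that a sequence's output does not change when it is computed alone versus inside a larger batch. Below is a minimal sketch of that pattern, not the actual test; `scaled_dot_product_attention` is a stand-in for the FA2 varlen kernel, and the shapes are illustrative (the real test also uses float16 on CUDA, num_splits=1, and a sliding window).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 4, 2, 64, 32  # illustrative shapes
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Run the whole batch at once, then the first sequence alone.
batched = F.scaled_dot_product_attention(q, k, v)
solo = F.scaled_dot_product_attention(q[:1], k[:1], v[:1])

# A batch-invariant kernel (e.g. FA2 with num_splits=1) must produce
# bitwise-identical results for the shared sequence.
print(torch.equal(batched[:1], solo))  # expected: True
```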
@liangel-02 your PR has been successfully reverted.
…commit" enable num_splits param for FA2 and if num_splits=1 then for paged KV align block size to match with standard kernel. for testing i used torch.equal() to check numerics in the existing paging test. since this involves changes in upstream flash attention repo, also bumping submodule commit [ghstack-poisoned]
…commit" enable num_splits param for FA2 and if num_splits=1 then for paged KV align block size to match with standard kernel. for testing i used torch.equal() to check numerics in the existing paging test. since this involves changes in upstream flash attention repo, also bumping submodule commit [ghstack-poisoned]
…commit" enable num_splits param for FA2 and if num_splits=1 then for paged KV align block size to match with standard kernel. for testing i used torch.equal() to check numerics in the existing paging test. since this involves changes in upstream flash attention repo, also bumping submodule commit [ghstack-poisoned]
|
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 1 check: trunk / linux-jammy-rocm-py3.10-mi355 / test (default, 5, 6, linux.rocm.gpu.gfx950.1)
Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Enable the num_splits param for FA2; if num_splits=1, align the paged-KV block size to match the standard kernel. For testing, I used torch.equal() to check numerics in the existing paging test; a sketch of that comparison pattern follows below.
Since this involves changes in the upstream flash attention repo, also bumping the submodule commit.
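As context, here is a minimal sketch of the comparison pattern described above; it is not the PR's code or the test in test/test_varlen_attention.py. It simulates a paged KV cache with a block table, reconstructs contiguous K/V, and checks bitwise equality against the standard run with torch.equal(). `scaled_dot_product_attention` stands in for the FA2 kernel, and names like `block_table` and the page-pool layout are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Illustrative shapes; PAGE stands in for a page size that is aligned with
# the standard kernel's KV block size.
B, H, S, D, PAGE = 2, 4, 128, 64, 16
n_pages = S // PAGE

q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Standard (contiguous) run: the reference output.
ref = F.scaled_dot_product_attention(q, k, v)

# Simulate a paged KV cache: split K/V into PAGE-sized blocks, scatter them
# into a shuffled physical page pool, and keep a block table mapping each
# sequence's logical pages to physical pages.
perm = torch.randperm(B * n_pages)
k_pool = k.reshape(B, H, n_pages, PAGE, D).permute(0, 2, 1, 3, 4).reshape(B * n_pages, H, PAGE, D)[perm]
v_pool = v.reshape(B, H, n_pages, PAGE, D).permute(0, 2, 1, 3, 4).reshape(B * n_pages, H, PAGE, D)[perm]
block_table = torch.argsort(perm).reshape(B, n_pages)

# Gather the pages back into contiguous K/V and rerun the same kernel. With
# the page size aligned to the kernel's block size (the num_splits=1 case),
# the result should be bitwise identical to the reference.
k_re = k_pool[block_table].permute(0, 2, 1, 3, 4).reshape(B, H, S, D)
v_re = v_pool[block_table].permute(0, 2, 1, 3, 4).reshape(B, H, S, D)
out = F.scaled_dot_product_attention(q, k_re, v_re)

assert torch.equal(ref, out)  # bitwise equality, not torch.allclose
```

The point of the bitwise check is that block-size alignment removes any difference in reduction order between the paged and standard paths, so torch.equal() can be used instead of a tolerance-based comparison.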
Stack from ghstack (oldest at bottom):