
[MPS] GridSampler2D backward#179756

Closed
malfet wants to merge 7 commits into gh/malfet/828/base from gh/malfet/828/head

Conversation


@malfet malfet commented Apr 8, 2026

Stack from ghstack (oldest at bottom):

Implements grid_sampler_2d_backward_mps as Metal kernel, ported from the CUDA implementation in GridSampler.cu / GridSampler.cuh. Supports all combinations of interpolation and padding modes for float32, float16, and bfloat16. Has been requested more than 10 times in #141287
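For intuition, here is a pure-Python sketch of what a single forward thread computes in the bilinear / zeros-padding / align_corners=True case (the names `unnormalize` and `bilinear_sample` are illustrative, not from the PR); the backward kernel differentiates exactly this expression with respect to the four input taps and to (x, y):

```python
import math

def unnormalize(coord, size, align_corners=True):
    # Map a normalized grid coordinate in [-1, 1] to pixel space,
    # mirroring grid_sampler's coordinate transform.
    if align_corners:
        return (coord + 1) / 2 * (size - 1)
    return ((coord + 1) * size - 1) / 2

def bilinear_sample(img, x, y):
    # img: 2D list (H x W); (x, y) already in pixel space.
    x0, y0 = int(math.floor(x)), int(math.floor(y))
    wx1, wy1 = x - x0, y - y0
    wx0, wy0 = 1 - wx1, 1 - wy1
    H, W = len(img), len(img[0])
    def at(r, c):  # zeros padding: out-of-bounds taps contribute 0
        return img[r][c] if 0 <= r < H and 0 <= c < W else 0.0
    return (at(y0, x0) * wy0 * wx0 + at(y0, x0 + 1) * wy0 * wx1 +
            at(y0 + 1, x0) * wy1 * wx0 + at(y0 + 1, x0 + 1) * wy1 * wx1)
```

The Metal kernel does the same per thread, looping over channels and the other padding/interpolation modes.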

Benchmark: grid_sampler_2d backward on Apple M5 Pro

Benchmark script
"""Benchmark for grid_sampler_2d backward on MPS."""
import time, torch

def bench(N, C, H, W, interp=0, dtype=torch.float32, memory_format=torch.contiguous_format,
          warmup=10, repeat=100):
    inp = torch.randn(N, C, H, W, device="mps", dtype=dtype)
    inp = inp.to(memory_format=memory_format).requires_grad_(True)
    grid = torch.randn(N, H, W, 2, device="mps", dtype=dtype)
    out = torch.grid_sampler_2d(inp, grid, interp, 0, True)  # padding_mode=0 (zeros), align_corners=True
    grad = torch.randn_like(out)
    # args: grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask
    args = (grad, inp, grid, interp, 0, True, [True, True])
    for _ in range(warmup):
        torch.ops.aten.grid_sampler_2d_backward(*args)
    torch.mps.synchronize()
    times = []
    for _ in range(repeat):
        torch.mps.synchronize()
        t0 = time.perf_counter()
        torch.ops.aten.grid_sampler_2d_backward(*args)
        torch.mps.synchronize()
        times.append(time.perf_counter() - t0)
    times.sort()
    t = max(1, len(times) // 10)  # trim the fastest and slowest ~10% of samples
    trimmed = times[t:-t]
    return sum(trimmed) / len(trimmed) * 1000  # average in ms

SHAPES = [(1,32,64,64), (1,16,256,256), (8,64,128,128), (16,128,256,256)]
INTERPS = {0: "bilinear", 2: "bicubic"}
FORMATS = {"contiguous": torch.contiguous_format, "channels_last": torch.channels_last}

print(f"{'dtype':<10} {'interp':<10} {'layout':<14} {'(N,C,H,W)':<24} {'ms':>8}")
print("-" * 70)
for dtype in [torch.float32, torch.float16]:
    for interp, iname in INTERPS.items():
        for fname, fmt in FORMATS.items():
            for shape in SHAPES:
                avg = bench(*shape, interp=interp, dtype=dtype, memory_format=fmt)
                print(f"{str(dtype):<10} {iname:<10} {fname:<14} {str(shape):<24} {avg:8.3f}")
dtype interp (N,C,H,W) contiguous (ms) channels_last (ms)
float32 bilinear (1, 32, 64, 64) 0.285 0.249
float32 bilinear (1, 16, 256, 256) 0.351 0.332
float32 bilinear (8, 64, 128, 128) 1.809 1.717
float32 bilinear (16, 128, 256, 256) 81.720 62.959
float32 bicubic (1, 32, 64, 64) 0.434 0.517
float32 bicubic (1, 16, 256, 256) 0.819 0.952
float32 bicubic (8, 64, 128, 128) 4.527 5.331
float32 bicubic (16, 128, 256, 256) 168.363 145.735
float16 bilinear (1, 32, 64, 64) 0.291 0.295
float16 bilinear (1, 16, 256, 256) 0.462 0.431
float16 bilinear (8, 64, 128, 128) 2.080 2.052
float16 bilinear (16, 128, 256, 256) 31.970 27.160
float16 bicubic (1, 32, 64, 64) 0.774 0.769
float16 bicubic (1, 16, 256, 256) 1.219 1.282
float16 bicubic (8, 64, 128, 128) 7.032 8.497
float16 bicubic (16, 128, 256, 256) 104.402 170.552

Notes:

  • Bilinear channels_last is up to 23% faster than contiguous (62.9 ms vs 81.7 ms for the largest float32 case)
  • Bicubic channels_last regresses for large float16 tensors (170.6 ms vs 104.4 ms), as the 4x4 spatial access pattern loses locality

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

[ghstack-poisoned]
@malfet malfet requested a review from mruberry as a code owner April 8, 2026 21:01

pytorch-bot Bot commented Apr 8, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/179756

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 10 Pending

As of commit e70f946 with merge base 13d0e53:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added labels ciflow/inductor, ciflow/mps (Run MPS tests, subset of trunk), release notes: mps, release notes: inductor (aoti) on Apr 8, 2026
malfet added a commit that referenced this pull request Apr 8, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: bec1805
Pull-Request: #179756

github-actions Bot commented Apr 8, 2026

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.




github-actions Bot commented Apr 8, 2026

Attention! One of the PyTorch C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change a function's signature, introduce a new v2 version of the function and modify code generation to target it.



@malfet malfet marked this pull request as draft April 8, 2026 23:14
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 9, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: b1269f0
Pull-Request: #179756
@malfet malfet requested a review from kurtamohler April 9, 2026 05:36
@malfet malfet marked this pull request as ready for review April 9, 2026 05:36
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 9, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: 2ada53e
Pull-Request: #179756
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 9, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: 1d382c7
Pull-Request: #179756
@kurtamohler (Collaborator) commented:

I suppose it's worth mentioning that #179388 was submitted to implement this already. But I'm guessing that this one might give better performance, since it can calculate both grads in the same kernel.


malfet commented Apr 9, 2026

> I suppose it's worth mentioning that #179388 was submitted to implement this already. But I'm guessing that this one might give better performance since it can calculate both grads in the same kernel

I waited for the author to fix the numerics for some time (even in their original #159421 PR), but it never happened.

Comment thread aten/src/ATen/native/mps/kernels/GridSampler.metal Outdated
@malfet malfet changed the title from [WIP][MPS] GridSampler2D backward to [MPS] GridSampler2D backward on Apr 9, 2026

// Bilinear backward kernel
template <typename Pad, typename T>
kernel void grid_sampler_2d_backward_bilinear(
Collaborator:

It would be nice to combine the bilinear, bicubic, and nearest kernels into one, with a function pointer template parameter to switch between them, since there are a lot of identical lines of code in the three kernels. But maybe it's not worth the trouble, if the function pointer would need a huge list of arguments

Contributor Author:

Have you tested which one is more performant? I know you pass mode as an argument to grid_sampler_3d, but for some reason CUDA follows this pattern for both interpolate and 2d... A template argument does sound reasonable to me.

Contributor Author:

I can benchmark it in a follow-up PR. This one is a literal transpilation of the CUDA shader into Metal, without significant architectural changes. But if selecting the interpolation mode dynamically performs the same as selecting it statically (via kernel dispatch), it indeed stands to reason to use it.

Collaborator:

I haven't tested the performance. I had assumed that the compiler would make them the same, but I agree that there is a chance it doesn't.

@mlaves
Contributor

mlaves commented Apr 10, 2026

> I suppose it's worth mentioning that #179388 was submitted to implement this already. But I'm guessing that this one might give better performance since it can calculate both grads in the same kernel

> I waited for the author to fix numeric for some time (even in their original #159421 issue), but it never happened

Hi, thanks for opening this PR. My original PR #159421 did not have any numerical issues and all tests were passing. I removed the GridSampler2D backward from my current PR #179388 to avoid conflicts with this PR.

Comment thread aten/src/ATen/native/mps/operations/GridSampler.mm Outdated
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 15, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: 2bfb259
Pull-Request: #179756
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 15, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: 6484abd
Pull-Request: #179756
[ghstack-poisoned]
malfet added a commit that referenced this pull request Apr 15, 2026
Which has been requested more than 10 times

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
ghstack-source-id: 95204a5
Pull-Request: #179756

malfet commented Apr 15, 2026

@pytorchbot merge -f "Lint + MPS are green"

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Apr 19, 2026
Implement the MPS backward pass for `grid_sampler_3d`, supporting bilinear and nearest interpolation across `float32`, `float16`, and `bfloat16` dtypes.

The backward kernels use `atomic<float>` for grad_input accumulation (Metal 3 lacks `atomic<half>`/`atomic<bfloat>`), with automatic dtype conversion on output. Intermediate computations use `opmath_t<T>` to maintain precision for reduced-precision types.
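The need for atomics can be seen in a toy single-channel Python sketch (illustrative names, not the PR's code): each backward thread scatters its weighted upstream gradient into up to four grad_input cells, and neighboring sampling locations hit the same cells, so the plain `+=` below must become an atomic add on the GPU:

```python
import math

def bilinear_backward_input(grad_out, grid_px, H, W):
    # grad_out[i]: upstream gradient for output location i
    # grid_px[i]: (x, y) sampling coordinate in pixel space
    grad_input = [[0.0] * W for _ in range(H)]
    for go, (x, y) in zip(grad_out, grid_px):
        x0, y0 = int(math.floor(x)), int(math.floor(y))
        wx1, wy1 = x - x0, y - y0
        taps = ((y0, x0, (1 - wx1) * (1 - wy1)),
                (y0, x0 + 1, wx1 * (1 - wy1)),
                (y0 + 1, x0, (1 - wx1) * wy1),
                (y0 + 1, x0 + 1, wx1 * wy1))
        for r, c, w in taps:
            if 0 <= r < H and 0 <= c < W:
                # Different threads can land on the same (r, c),
                # which is why the GPU kernel needs an atomic add here.
                grad_input[r][c] += w * go
    return grad_input
```

In the serial loop the collisions are harmless; in the parallel kernel they are data races unless the accumulation is atomic, which is what forces the `atomic<float>` buffer.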

The forward kernel is restructured to match the 2D kernel pattern: one thread per spatial position (`N*D*H*W`) with a channel loop, eliminating redundant grid reads and coordinate transforms across channels. Nearest-neighbor interpolation is also added for the 3D forward pass.

The `grid_sampler_2d` backward has been removed from this PR since
#179756 implements it independently. This PR is designed to rebase cleanly on top of that change.

This is a reactivation of my stale PR #159421. @Skylion007 @kurtamohler
Pull Request resolved: #179388
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
6 participants