
Add conditional replacement of @torch.inference_mode for inference on AMD DirectML GPUs #3295

Open

deruyter92 wants to merge 9 commits into main from jaap/amd_direct_ml_inference

Conversation

@deruyter92 (Collaborator)

Motivation
Currently, our inference runners use @torch.inference_mode, which is not supported on AMD GPUs running inference through DirectML. Essentially, @torch.inference_mode is a stricter version of @torch.no_grad; it is newer and faster and works for most users. However, since it does not work for AMD DirectML users, it is worthwhile to conditionally replace it with @torch.no_grad when necessary.
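To illustrate why @torch.inference_mode is the stricter of the two: tensors created under it are "inference tensors" that can never be used in autograd afterwards, whereas @torch.no_grad merely skips gradient tracking:

```python
import torch

x = torch.ones(3)

with torch.no_grad():
    y = x * 2  # regular tensor; gradients simply were not recorded

with torch.inference_mode():
    z = x * 2  # inference tensor, subject to extra restrictions

y.requires_grad_(True)    # fine
# z.requires_grad_(True)  # raises: inference tensors cannot be used in autograd
```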

solves #3289

Changes
This PR replaces @torch.inference_mode with a conditional @_no_grad_decorator, controlled by the env variable DLC_DIRECTML_NO_GRAD. The decorator resolves to @torch.no_grad for DirectML users who set the env var to "true"; otherwise it defaults to @torch.inference_mode (keeping the current behavior).
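A minimal sketch of the idea, assuming the decorator is resolved once at import time (the `inference` function below is a hypothetical stand-in for the actual runner methods):

```python
import os

import torch

# Opt-in flag for DirectML users; anything other than "true" keeps the default.
_USE_NO_GRAD = os.getenv("DLC_DIRECTML_NO_GRAD", "false").strip().lower() == "true"

# torch.no_grad for DirectML users who opted in, torch.inference_mode otherwise.
_no_grad_decorator = torch.no_grad if _USE_NO_GRAD else torch.inference_mode


@_no_grad_decorator()
def inference(model, batch):
    """Run a forward pass without gradient tracking."""
    return model(batch)
```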

Copilot AI (Contributor) left a comment


Pull request overview

Adds an opt-in workaround for AMD DirectML inference by conditionally using torch.no_grad() instead of torch.inference_mode() in the PyTorch inference runners, controlled via an environment variable.

Changes:

  • Introduces DLC_DIRECTML_NO_GRAD env var parsing and a conditional decorator to select no_grad vs inference_mode.
  • Applies the conditional decorator to InferenceRunner.inference and CTDInferenceRunner.inference in place of @torch.inference_mode().
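For an AMD DirectML user, opting into the workaround could look like the following (assuming, as sketched above, that the flag is read when the runner module is imported):

```python
import os

# Hypothetical usage: enable the no_grad fallback before importing DeepLabCut,
# so the env variable is already set when the decorator is resolved.
os.environ["DLC_DIRECTML_NO_GRAD"] = "true"

import deeplabcut  # noqa: E402
```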


@deruyter92 deruyter92 marked this pull request as ready for review April 28, 2026 08:57
@deruyter92 deruyter92 requested a review from C-Achard April 28, 2026 08:57
@C-Achard (Collaborator) left a comment


Good fix, thanks!
Just minor comments, otherwise LGTM

@C-Achard (Collaborator) left a comment


Thanks for the docs update!



Development

Successfully merging this pull request may close these issues.

[Bug] DeepLabCut 3.0 PyTorch with AMD GPU (DirectML) fails on ConvTranspose2d in inference_mode — Workaround: use no_grad
