Add conditional replacement of @torch.inference_mode for inference on AMD DirectML GPUs #3295
Open
deruyter92 wants to merge 9 commits
Conversation
AMD GPUs running inference through DirectML currently do not support `torch.inference_mode`, which is stricter than `torch.no_grad`. This commit fixes that by adding a conditional `no_grad_decorator`, controlled by the env variable `DLC_DIRECTML_NO_GRAD`: it resolves to `@torch.no_grad` when the variable is set to `"true"`, and to `@torch.inference_mode` otherwise (the default).
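For context, a minimal sketch of how such a conditional decorator can be wired up. The helper name `_resolve_no_grad_decorator`, the parsing details, and the method signature are illustrative assumptions; only `_no_grad_decorator`, the env variable, and `InferenceRunner.inference` come from this PR:

```python
import os

import torch


def _resolve_no_grad_decorator():
    """Select the gradient-disabling decorator from DLC_DIRECTML_NO_GRAD.

    torch.no_grad() is the less strict fallback that works on AMD DirectML;
    torch.inference_mode() stays the default because it is newer and faster.
    """
    # Hypothetical parsing; the merged code may normalize the value differently.
    if os.environ.get("DLC_DIRECTML_NO_GRAD", "false").lower() == "true":
        return torch.no_grad()
    return torch.inference_mode()


# Resolved once at import time, so the env var must be set beforehand.
_no_grad_decorator = _resolve_no_grad_decorator()


class InferenceRunner:
    @_no_grad_decorator
    def inference(self, inputs):
        # The forward pass runs with autograd disabled under either decorator.
        ...
```

Both `torch.no_grad()` and `torch.inference_mode()` instances double as decorators, which is what makes this one-line swap possible.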
Contributor
Pull request overview
Adds an opt-in workaround for AMD DirectML inference by conditionally using torch.no_grad() instead of torch.inference_mode() in the PyTorch inference runners, controlled via an environment variable.
Changes:
- Introduces `DLC_DIRECTML_NO_GRAD` env var parsing and a conditional decorator to select `no_grad` vs `inference_mode`.
- Applies the conditional decorator to `InferenceRunner.inference` and `CTDInferenceRunner.inference` in place of `@torch.inference_mode()`.
C-Achard
approved these changes
Apr 28, 2026
Collaborator
C-Achard
left a comment
Good fix, thanks!
Just minor comments, otherwise LGTM
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
C-Achard
approved these changes
Apr 28, 2026
Collaborator
C-Achard
left a comment
Thanks for the docs update!
Motivation
Currently, our inference runners use `@torch.inference_mode`, which is not supported on AMD GPUs with DirectML. Essentially, `@torch.inference_mode` is a stricter version of `@torch.no_grad`; it is newer and faster and works for most users. However, since it does not work for AMD DirectML users, it is worthwhile to conditionally replace it with `@torch.no_grad` when necessary.

Solves #3289
Changes
This PR replaces `@torch.inference_mode` with a conditional `@_no_grad_decorator`, controlled by the env variable `DLC_DIRECTML_NO_GRAD`. The decorator resolves to `@torch.no_grad` for DirectML users if they set the env var to `"true"`; otherwise it defaults to `@torch.inference_mode`, keeping the current behavior. A usage sketch follows below.
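For illustration, an AMD DirectML user would opt in as follows. Since the sketch above resolves the decorator at import time, the variable must be set before the runners are imported (an assumption about this PR's load order):

```python
import os

# Opt in to the torch.no_grad fallback for AMD DirectML GPUs.
# Equivalent to `export DLC_DIRECTML_NO_GRAD=true` in the shell.
os.environ["DLC_DIRECTML_NO_GRAD"] = "true"

# ...import and run the inference runners afterwards.
```

Everyone else leaves the variable unset and keeps the faster `torch.inference_mode` path unchanged.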