diff --git a/deeplabcut/pose_estimation_pytorch/runners/inference.py b/deeplabcut/pose_estimation_pytorch/runners/inference.py
index 55f1c531c..367685bb3 100644
--- a/deeplabcut/pose_estimation_pytorch/runners/inference.py
+++ b/deeplabcut/pose_estimation_pytorch/runners/inference.py
@@ -10,10 +10,12 @@
 #
 from __future__ import annotations
 
+import os
 import threading
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections.abc import Iterable
+from contextlib import contextmanager
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from queue import Empty, Full, Queue
@@ -39,6 +41,41 @@
 )
 from deeplabcut.pose_estimation_pytorch.task import Task
 
+# NOTE @deruyter92 2026-04-28: AMD GPUs running inference through DirectML currently
+# do not support torch.inference_mode, which is stricter than torch.no_grad. This ENV
+# variable is used to conditionally fall back to torch.no_grad instead. See PR #3295.
+_directml_no_grad: bool = os.getenv("DLC_DIRECTML_NO_GRAD", "false").lower() in (
+    "true",
+    "1",
+)
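+# The flag is read only once, at import time: set DLC_DIRECTML_NO_GRAD before
+# importing this module (the accompanying tests reload the module to change it).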
+
+
+def _inference_mode_decorator(fn):
+    """
+    Conditional decorator for inference mode, controlled by the DLC_DIRECTML_NO_GRAD ENV variable.
+    Uses @torch.no_grad if it is set to "true", otherwise defaults to @torch.inference_mode.
+    """
+    return torch.no_grad()(fn) if _directml_no_grad else torch.inference_mode()(fn)
+
+
+@contextmanager
+def _directml_runtime_error_hint():
+    """Context manager that augments runtime errors with a hint for DirectML-related issues."""
+    try:
+        yield
+    except RuntimeError as e:
+        if torch.is_inference_mode_enabled() and not _directml_no_grad:
+            raise RuntimeError(
+                f"{e}\n\n"
+                "If you are using an AMD GPU with DirectML, this error may be caused by "
+                "@torch.inference_mode being incompatible with the DirectML execution path. "
+                "Try setting the environment variable DLC_DIRECTML_NO_GRAD=true, "
+                "which will switch the inference context to @torch.no_grad."
+            ) from e
+        raise
+
 
 def _merge_defaults(cls, data: dict[str, Any]):
     """
@@ -268,7 +305,7 @@ def predict(self, inputs: torch.Tensor, **kwargs) -> list[dict[str, dict[str, np
             the predictions for each of the 'batch_size' inputs
         """
 
-    @torch.inference_mode()
+    @_inference_mode_decorator
     def inference(
         self,
         images: (Iterable[str | Path | np.ndarray] | Iterable[tuple[str | Path | np.ndarray, dict[str, Any]]]),
@@ -353,7 +390,8 @@ def _async_inference(
                 batch, model_kwargs = item
 
                 # Run model inference
-                predictions = self.predict(batch, **model_kwargs)
+                with _directml_runtime_error_hint():
+                    predictions = self.predict(batch, **model_kwargs)
                 self._predictions.extend(predictions)
 
         # Extract and return results
@@ -463,7 +501,8 @@ def _process_batch(self) -> None:
         batch = torch.stack(self._batch_list[: self.batch_size], dim=0)
         model_kwargs = {mk: v[: self.batch_size] for mk, v in self._model_kwargs.items()}
 
-        self._predictions += self.predict(batch, **model_kwargs)
+        with _directml_runtime_error_hint():
+            self._predictions += self.predict(batch, **model_kwargs)
 
         # remove processed inputs
         if len(self._batch_list) <= self.batch_size:
@@ -648,7 +687,7 @@ def __init__(
         self._idx_to_id = None
         self._ctd_track_ages = None  # the age of each CTD tracklet
 
-    @torch.inference_mode()
+    @_inference_mode_decorator
     def inference(
         self,
         images: (Iterable[str | Path | np.ndarray] | Iterable[tuple[str | Path | np.ndarray, dict[str, Any]]]),
@@ -752,7 +791,8 @@ def add_conditions(
             inputs = torch.as_tensor(image)
 
             # Get and post-process the predictions
-            predictions = self.bu_runner.predict(inputs)
+            with _directml_runtime_error_hint():
+                predictions = self.bu_runner.predict(inputs)
             if self.bu_runner.postprocessor is not None:
                 predictions, context = self.bu_runner.postprocessor(predictions, context)
@@ -775,7 +815,8 @@ def _ctd_tracking_inference(
         for data in images:
             inputs, context = self._prepare_ctd_inputs(data)
             model_kwargs = context.pop("model_kwargs", {})
-            predictions = self.predict(inputs, **model_kwargs)
+            with _directml_runtime_error_hint():
+                predictions = self.predict(inputs, **model_kwargs)
             if self.postprocessor is not None:
                 # Pop the "cond_kpts" from the context so there's no re-scoring
                 # This is required when tracking with CTD, otherwise scores go to 0
diff --git a/docs/recipes/TechHardware.md b/docs/recipes/TechHardware.md
index 699d186c2..9ab75ade2 100644
--- a/docs/recipes/TechHardware.md
+++ b/docs/recipes/TechHardware.md
@@ -19,6 +19,18 @@ For reference, we use e.g. Dell workstations (79xx series) with **Ubuntu 16.04 L
 
 Ideally, you will use a strong GPU with *at least* 8GB memory such as the [NVIDIA GeForce 1080 Ti, 2080 Ti, or 3090](https://marketplace.nvidia.com/en-us/consumer/graphics-cards/). A GPU is not strictly necessary, but on a CPU the (training and evaluation) code is considerably slower (10x) for ResNets, but MobileNets and EfficientNets are slightly faster. Still, a GPU will give you a massive speed boost. You might also consider using cloud computing services like [Google cloud/amazon web services](https://github.com/DeepLabCut/DeepLabCut/issues/47) or Google Colaboratory.
 
+```{note}
+If you encounter errors during inference related to
+`torch.inference_mode` and DirectML, set the environment variable
+`DLC_DIRECTML_NO_GRAD=true` before starting Python. This switches the inference
+context to `torch.no_grad`, which is compatible with the DirectML execution path.
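+
+The variable is read when the inference module is imported, so it must be set
+before DeepLabCut is imported. For example, launch Python as
+`DLC_DIRECTML_NO_GRAD=true python my_script.py` (Linux/macOS), or put
+`os.environ["DLC_DIRECTML_NO_GRAD"] = "true"` at the very top of your script.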
+```
+
 ### Camera Hardware:
 
 The software is very robust to track data from any camera (cell phone cameras, grayscale, color; captured under infrared light, different manufacturers, etc.). See demos on our [website](https://www.mousemotorlab.org/deeplabcut/).
diff --git a/tests/pose_estimation_pytorch/runners/test_inference_directml_no_grad.py b/tests/pose_estimation_pytorch/runners/test_inference_directml_no_grad.py
new file mode 100644
index 000000000..318c909e3
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_inference_directml_no_grad.py
@@ -0,0 +1,67 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests that DLC_DIRECTML_NO_GRAD toggles inference_mode vs no_grad (AMD DirectML)."""
+
+from __future__ import annotations
+
+import importlib
+import os
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.runners.inference as inference
+
+
+def _reload_with_env(env_value: str | None):
+    if env_value is None:
+        os.environ.pop("DLC_DIRECTML_NO_GRAD", None)
+    else:
+        os.environ["DLC_DIRECTML_NO_GRAD"] = env_value
+    importlib.reload(inference)
+
+
+@pytest.fixture(autouse=True)
+def _restore_env():
+    yield
+    _reload_with_env(None)  # always restore defaults after each test
+
+
+@pytest.mark.parametrize(
+    ("env_value", "directml_no_grad"),
+    [(None, False), ("false", False), ("true", True)],
+)
+def test_directml_no_grad_env(env_value, directml_no_grad):
+    """env var sets _directml_no_grad and selects the correct torch grad context."""
+    _reload_with_env(env_value)
+    assert inference._directml_no_grad is directml_no_grad
+
+    class _SniffRunner(inference.InferenceRunner):
+        def __init__(self):
+            super().__init__(
+                model=Mock(),
+                batch_size=1,
+                inference_cfg=inference.InferenceConfig(
+                    multithreading=inference.MultithreadingConfig(enabled=False),
+                ),
+            )
+            self.saw_inference_mode: bool | None = None
+
+        def predict(self, inputs: torch.Tensor, **kwargs):
+            self.saw_inference_mode = torch.is_inference_mode_enabled()
+            return [{"mock": {"poses": np.zeros((1,), dtype=np.float32)}}]
+
+    runner = _SniffRunner()
+    runner.inference([np.zeros((1, 3, 8, 8), dtype=np.float32)])
+
+    assert runner.saw_inference_mode is not directml_no_grad
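
For a quick local check of the toggle, here is a minimal usage sketch. It is not part of the patch and assumes only what the diff establishes: `DLC_DIRECTML_NO_GRAD` is read once at import time, so it must be set before DeepLabCut is imported. The config and video paths are placeholders.

```python
# Minimal sketch: enable the torch.no_grad fallback before any deeplabcut import,
# since the DLC_DIRECTML_NO_GRAD flag is read once when the inference module loads.
import os

os.environ["DLC_DIRECTML_NO_GRAD"] = "true"

import deeplabcut  # noqa: E402 - must come after the environment variable is set

# Placeholder paths: substitute your own project config and video.
deeplabcut.analyze_videos("/path/to/config.yaml", ["/path/to/video.mp4"])
```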