Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
cc0c8e3
unsilence DLCDeprecationWarning: show once per instance
deruyter92 Jun 25, 2026
474f2df
add modules deeplabcut.api and deeplabcu.tensorflow_compat (init only)
deruyter92 Jun 25, 2026
88b2f24
add tensorflow routing for deprecated tensorflow API
deruyter92 Jun 25, 2026
0889c1c
copy compat into tensorflow_compat/tensorflow_api
deruyter92 Jun 25, 2026
0dc7fd6
update _tf_routing: resolve deprecated arguments
deruyter92 Jun 25, 2026
1da630e
update _tf_routing.py: improve engine resolution
deruyter92 Jun 25, 2026
bb876c5
copy compat into api/pose_estimation
deruyter92 Jun 25, 2026
2e4892f
remove docstings in pose_estimation.py
deruyter92 Jun 25, 2026
a5f7bd1
temp clean pose_estimation_api
deruyter92 Jun 25, 2026
f75c575
update _tf_routing: require dropped, renamed params, and normalize gp…
deruyter92 Jun 25, 2026
2533741
add canoncical pose estimation API with TF fallback
deruyter92 Jun 25, 2026
b160f02
update tensorflow API in tensorflow_compat
deruyter92 Jun 25, 2026
b986676
fix Engine imports across repo (core, not compat)
deruyter92 Jun 26, 2026
314ed76
hook up deeplabcut.api.pose_estimation instead of compat (main init +…
deruyter92 Jun 26, 2026
542cae8
add tests for _tf_routing
deruyter92 Jun 26, 2026
a9ec700
generate_training_dataset: update Engine import (until not needed)
deruyter92 Jun 26, 2026
dfe435c
remove TensorFlow branches in generate_training_dataset
deruyter92 Jun 26, 2026
11d1c21
copy TensorFlow-only branches of generate_training_dataset to tensor…
deruyter92 Jun 26, 2026
b5391ff
remove engine parameter in generate_training_dataset (and in the tens…
deruyter92 Jun 26, 2026
b9afbe3
adjust tf_routing: import top-level tensorflow_compat
deruyter92 Jun 26, 2026
3cad8ef
fix rename_parameter order in evaluate_network API
deruyter92 Jun 26, 2026
dfa6897
hook up data_management / generate_training_dataset API
deruyter92 Jun 26, 2026
6ce95c3
deeplabcut.api.pose_estimation fix incorrect dropped parameter destfo…
deruyter92 Jun 26, 2026
d0f3e7e
move get_available_aug_methods to engine
deruyter92 Jun 26, 2026
483d373
add smoke tests for split pose_estimation API
deruyter92 Jun 26, 2026
902613d
add tests for return_network_path and visualization
deruyter92 Jun 26, 2026
4da05d6
Revert generate_training_dataset migration (simplify PR)
deruyter92 Jun 26, 2026
09f7174
rename tensorflow_compat tensorflow_api to pose_estimation
deruyter92 Jun 26, 2026
1edf3ac
add missing from_future import
deruyter92 Jun 26, 2026
dc50698
Fix _tf_routing: allow use of legacy capitalized `Shuffles` for engin…
deruyter92 Jun 26, 2026
1c58ecf
rename tests to avoid naming colision
deruyter92 Jun 26, 2026
864c7b5
Merge remote-tracking branch 'upstream/main' into jaap/prepare_tf_dep…
deruyter92 Jun 26, 2026
063a84e
fix merge conflix: import deprecation from core
deruyter92 Jun 26, 2026
d3f8488
remove deeplabcut.compat in favor of deeplabcut.api
deruyter92 Jun 26, 2026
a0ab176
lazy import of get_shuffle_engine
deruyter92 Jun 26, 2026
f35e425
fix mock in test_tf_routing
deruyter92 Jun 26, 2026
d0635ca
add GUI warning when selecting TensorFlow
deruyter92 Jun 26, 2026
95065d9
update _tf_routing allow arg config if absent from kwargs
deruyter92 Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions deeplabcut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,19 @@
if DEBUG:
logger.debug("Loading DLC %s", VERSION)

# DeepLabCut deprecation warnings are shown only once per message instance.
import warnings

from deeplabcut.core.deprecation import DLCDeprecationWarning

warnings.filterwarnings("once", category=DLCDeprecationWarning)

# -----------------------------------------------------------------------------
# Always-available public API
# -----------------------------------------------------------------------------

# Train / evaluate / predict functions (compat layer)
from .compat import (
# Train / evaluate / predict functions
from .api.pose_estimation import (
analyze_images,
analyze_time_lapse_frames,
analyze_videos,
Expand Down
44 changes: 44 additions & 0 deletions deeplabcut/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
"""
Public API for DeepLabCut.
"""

from __future__ import annotations

from typing import Any

__all__ = [
"analyze_images",
"analyze_time_lapse_frames",
"analyze_videos",
"convert_detections2tracklets",
"create_tracking_dataset",
"evaluate_network",
"export_model",
"extract_maps",
"extract_save_all_maps",
"return_evaluate_network_data",
"return_train_network_path",
"train_network",
"visualize_locrefs",
"visualize_paf",
"visualize_scoremaps",
]


def __getattr__(name: str) -> Any:
if name not in __all__:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

from deeplabcut.api import pose_estimation

return getattr(pose_estimation, name)
180 changes: 180 additions & 0 deletions deeplabcut/api/_tf_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
"""
Routing for legacy TensorFlow API while still supported. Remove this module when TF support is dropped.
"""

import warnings
from collections.abc import Callable
from functools import lru_cache
from importlib import import_module

from deeplabcut.core.deprecation import DLCDeprecationWarning
from deeplabcut.core.engine import Engine
from deeplabcut.utils.auxiliaryfunctions import read_config

_TF_MODULE = "deeplabcut.tensorflow_compat"


@lru_cache
def _get_tensorflow_impl(name: str):
return getattr(import_module(_TF_MODULE), name)


def warn_deprecated_tensorflow():
warnings.warn(
"\n"
"━" * 60 + "\n"
"⚠️ DeepLabCut — TensorFlow support is deprecated\n"
"━" * 60 + "\n"
"TensorFlow support will be removed in a future release.\n"
"Your project config and annotated data are fully compatible with PyTorch.\n"
"Please run create_training_dataset with any PyTorch model architecture to switch to PyTorch.\n"
"See our docs for more information: https://deeplabcut.github.io/DeepLabCut/docs/pytorch/architectures.html\n"
"━" * 60,
DLCDeprecationWarning,
stacklevel=3,
)


def with_tensorflow_fallback(
_fn: Callable | None = None,
*,
tensorflow_name: str | None = None,
renamed_params: dict[str, str] | None = None,
dropped_params: list[str] | None = None,
normalize_gputouse: bool = False,
) -> Callable:
"""Decorator for wrapping canonical PyTorch API functions, routing to a fallback TF function if required.
It automatically resolves the engine and converts legacy TensorFlow kwargs to canonical PyTorch kwargs, if needed.
Can be used with or without parentheses.

Args:
tensorflow_name (str | None): The name of the fallback TensorFlow function in ``_TF_MODULE``. If not specified,
uses the name of the canonical PyTorch function.
renamed_params (dict[str, str] | None): Optional mapping from old TF parameter names to the new canonical
PyTorch names. A warning will be emitted and the value is passed under the new canonical name. If both the
old and new names are specified, raises a TypeError.
dropped_params (list[str] | None): Optional list of dropped TensorFlow parameter names. A warning will be
emitted and the value is ignored.
normalize_gputouse (bool): resolve the old TF ``gputouse`` parameter to the new canonical PyTorch ``device``
parameter. Raises a TypeError if both are specified.

Note:
The engine is resolved from the shuffle metadata if not specified explicitly. If neither ``shuffles``,
``shuffle`` or ``engine`` is passed, it assumes shuffle=1.
"""

def decorator(fn):
tf_name = tensorflow_name or fn.__name__

def wrapper(*args, **kwargs):
engine = _resolve_engine(*args, **kwargs)
kwargs.pop("engine", None)
if engine == Engine.TF:
warn_deprecated_tensorflow()
return _get_tensorflow_impl(tf_name)(*args, **kwargs)
kwargs = _resolve_legacy_kwargs(
kwargs,
renamed_params=renamed_params or {},
dropped_params=dropped_params or [],
normalize_gputouse=normalize_gputouse,
)
return fn(*args, **kwargs)

return wrapper

if _fn is not None:
return decorator(_fn)
return decorator


def _shuffles_from_kwargs(kwargs: dict) -> list | tuple:
"""Return shuffle indices from kwargs, accepting legacy ``Shuffles``."""
if "shuffles" in kwargs and "Shuffles" in kwargs:
raise TypeError("Cannot specify both 'Shuffles' (deprecated) and 'shuffles'. Use 'shuffles' only.")
if "shuffles" in kwargs:
return kwargs["shuffles"]
if "Shuffles" in kwargs:
return kwargs["Shuffles"]
return [kwargs.get("shuffle", 1)]


def _resolve_engine(*args, **kwargs) -> Engine:
"""Resolve engine from explicit engine parameter or shuffle metadata."""
engine = kwargs.get("engine")
if engine is not None:
return engine

shuffles = _shuffles_from_kwargs(kwargs)
config = kwargs["config"] if "config" in kwargs else args[0]
cfg = read_config(config)
from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine

engines = {
get_shuffle_engine(
cfg,
trainingsetindex=kwargs.get("trainingsetindex", 0),
shuffle=s,
modelprefix=kwargs.get("modelprefix", ""),
)
for s in shuffles
}
if len(engines) > 1:
raise ValueError(f"All shuffles must have the same engine (found different engines for shuffles: {shuffles}).")
return engines.pop()


def _normalize_gputouse(gputouse: str | int) -> str:
if isinstance(gputouse, int):
return f"cuda:{gputouse}"
if gputouse.startswith("cuda:"):
return gputouse
if gputouse.startswith("gpu:"):
return gputouse.replace("gpu:", "cuda:")
return gputouse


def _resolve_legacy_kwargs(
kwargs: dict,
renamed_params: dict[str, str],
dropped_params: list[str],
normalize_gputouse: bool = False,
) -> dict:
"""Resolve legacy TensorFlow kwargs to canonical (PyTorch) kwargs."""

if normalize_gputouse and (gpu := kwargs.get("gputouse")):
# Normalize parameter "gputouse" to torch device string and rename
kwargs["gputouse"] = _normalize_gputouse(gpu)
renamed_params["gputouse"] = "device"

# Rename deprecated parameters
for old, new in renamed_params.items():
if old in kwargs:
if new in kwargs:
raise TypeError(f"Cannot specify both '{old}' (deprecated) and '{new}'. Use '{new}' only.")
kwargs[new] = kwargs.pop(old)
warnings.warn(
f"'{old}' is deprecated; use {new}='{kwargs[new]}' instead.",
DLCDeprecationWarning,
stacklevel=3,
)

# Drop unused parameters
for key in dropped_params:
if key in kwargs:
kwargs.pop(key)
warnings.warn(
f"'{key}' is a TensorFlow-only parameter and has no effect for PyTorch projects.",
DLCDeprecationWarning,
stacklevel=3,
)
return kwargs
Loading
Loading