Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
0814b2c
Add external detector registry, base, and builder
C-Achard Apr 8, 2026
2e33596
Support external detectors in inference runner
C-Achard Apr 8, 2026
78d088b
Add MockExternalDetector and integration test
C-Achard Apr 8, 2026
4a7f55c
Clarify external detector training expectations
C-Achard Apr 8, 2026
f61dc88
Guard external detector eval/params with logs
C-Achard Apr 8, 2026
3102223
Support detector-based bbox matching and methods
C-Achard Apr 8, 2026
8bf875a
Use BBoxComputationMethod; avoid recomputing bboxes
C-Achard Apr 8, 2026
450383e
Add default bbox method and use enum
C-Achard Apr 8, 2026
136610b
Add bbox schemas and precomputed detector runner
C-Achard Apr 8, 2026
ee2ec85
Add tests for bbox sources and precomputed runner
C-Achard Apr 8, 2026
649f9ba
Extract BBoxComputationMethod to bboxes module
C-Achard Apr 8, 2026
9695c84
Use string enums and resolve bbox source
C-Achard Apr 8, 2026
c7f4fe6
Export PrecomputedDetectorRunner in detectors
C-Achard Apr 8, 2026
654b614
Use SQS keys and consistent paths in test config
C-Achard Apr 8, 2026
ef80d71
Fix task check in pose export tests
C-Achard Apr 8, 2026
b267a4d
Assign detector bbox for single candidate
C-Achard Apr 8, 2026
e50695b
Update test_precomputed_bbox.py
C-Achard Apr 8, 2026
cb13c0d
Use keypoints-derived bbox for matching
C-Achard Apr 8, 2026
dcde791
Add precomputed detector utilities
C-Achard Apr 8, 2026
c9c30ef
Add detector runners for top-down training
C-Achard Apr 8, 2026
4f0c997
Add precomputed bbox & external detector support
C-Achard Apr 8, 2026
5d95087
Add e2e test for precomputed top-down training
C-Achard Apr 8, 2026
916f5cd
Update deeplabcut/pose_estimation_pytorch/data/base.py
C-Achard Apr 9, 2026
01ba41e
Update deeplabcut/pose_estimation_pytorch/apis/training.py
C-Achard Apr 9, 2026
6ceae5b
Add DetectorToPoseInferenceRunner and wiring
C-Achard Apr 9, 2026
e1f30fa
Add tests for DetectorThenPoseInferenceRunner
C-Achard Apr 9, 2026
cd49492
Create detector_test_full_api.py
C-Achard Apr 9, 2026
6813b22
Support external detectors in pose config
C-Achard Apr 10, 2026
718e622
Refactor detector_test_full_api example and config
C-Achard Apr 10, 2026
747dd76
Add external detector workflow script
C-Achard Apr 10, 2026
129192c
Update external detector workflow examples
C-Achard Apr 10, 2026
e984788
Rename detector class and update docs
C-Achard Apr 14, 2026
7f9d376
Simplify post-run messages in workflow script
C-Achard Apr 14, 2026
5cae337
Add cxcywh bbox support and resilient annotation merge
C-Achard Apr 14, 2026
a2978b7
Add detector-to-pose runner and normalize predictions
C-Achard Apr 14, 2026
09ccf70
Update external_detector_workflow.py
C-Achard Apr 14, 2026
b6976a0
Add note to create training shuffle in docs
C-Achard Apr 14, 2026
b1f6f2d
Disable bbox_fallback_to_gt by default
C-Achard Apr 15, 2026
e74d8e6
Make bbox margin configurable in DLCLoader
C-Achard Jun 15, 2026
30c3db7
Allow None for numeric init parameters
C-Achard Jun 15, 2026
021b14a
Normalize detector outputs and optional max individuals
C-Achard Jun 15, 2026
6f843a6
Simplify DetectorToPose inference wrapper
C-Achard Jun 15, 2026
8058e41
Assert detector receives raw image inputs
C-Achard Jun 15, 2026
363b133
Apply suggestions from code review
C-Achard Jun 15, 2026
4704cec
Update detector_test_full_api.py
C-Achard Jun 15, 2026
eafdf0a
Update external_detector_workflow.py
C-Achard Jun 15, 2026
07325b0
Merge branch 'main' into cy/h-detectors
C-Achard Jun 15, 2026
664c59d
Style: reflow strings and remove trailing space
C-Achard Jun 15, 2026
b69e753
Swap decorator to the new DirectML fix
C-Achard Jun 15, 2026
451f9d1
Improve precomputed detector path matching
C-Achard Jun 15, 2026
5d5b52b
Add tests for precomputed runner path lookup
C-Achard Jun 15, 2026
1a0968f
Use device-specific GPU memory reporting
C-Achard Jun 15, 2026
c1c4077
Validate and handle precomputed detector bboxes
C-Achard Jun 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions deeplabcut/pose_estimation_pytorch/apis/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
)
from deeplabcut.pose_estimation_pytorch.data.collate import COLLATE_FUNCTIONS
from deeplabcut.pose_estimation_pytorch.models import DETECTORS, PoseModel
from deeplabcut.pose_estimation_pytorch.models.detectors.external.base import (
build_precomputed_detector_runner_from_config,
)
from deeplabcut.pose_estimation_pytorch.modelzoo.memory_replay import (
prepare_memory_replay,
)
Expand Down Expand Up @@ -142,8 +145,56 @@ def train(
logging.info(f" Training: {transform}")
logging.info(f" Validation: {inference_transform}")

train_dataset = loader.create_dataset(transform=transform, mode="train", task=task)
valid_dataset = loader.create_dataset(transform=inference_transform, mode="test", task=task)
train_detector_runner = None
valid_detector_runner = None

data_cfg = loader.model_cfg.get("data", {})
bbox_source = data_cfg.get("bbox_source")
precomputed_bboxes = data_cfg.get("precomputed_bboxes")

if task == Task.TOP_DOWN and bbox_source == "detection_bbox":
if not precomputed_bboxes:
raise ValueError(
"data.bbox_source='detection_bbox' was requested for top-down pose "
"training, but data.precomputed_bboxes is not configured. "
"Please provide a BBoxes JSON artifact or set data.bbox_source to "
"'gt' or 'keypoints'."
)

validate_image_paths = data_cfg.get("bbox_validate_image_paths", False)

train_detector_runner = build_precomputed_detector_runner_from_config(
loader.model_cfg,
mode="train",
target_format="xywh",
validate_image_paths=validate_image_paths,
)
valid_detector_runner = build_precomputed_detector_runner_from_config(
loader.model_cfg,
mode="test",
target_format="xywh",
validate_image_paths=validate_image_paths,
)
elif precomputed_bboxes:
logging.info(
"data.precomputed_bboxes is configured but data.bbox_source=%r. "
"Ignoring precomputed detector boxes for this training run.",
bbox_source,
)

train_dataset = loader.create_dataset(
transform=transform,
mode="train",
task=task,
detector_runner=train_detector_runner,
)

valid_dataset = loader.create_dataset(
transform=inference_transform,
mode="test",
task=task,
detector_runner=valid_detector_runner,
)

collate_fn = None
if collate_fn_cfg := run_config["data"]["train"].get("collate"):
Expand Down
167 changes: 157 additions & 10 deletions deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
from __future__ import annotations

import copy
import logging
from enum import Enum
from pathlib import Path
from typing import Literal

from deeplabcut.core.config import read_config_as_dict, write_config
from deeplabcut.core.weight_init import WeightInitialization
Expand All @@ -24,20 +27,68 @@
replace_default_values,
update_config,
)
from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod
from deeplabcut.pose_estimation_pytorch.runners.inference import InferenceConfig
from deeplabcut.pose_estimation_pytorch.task import Task
from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions

logger = logging.getLogger(__name__)


def _yaml_safe_value(value):
    """Convert a config value to YAML-safe built-in Python types.

    Conversions applied:
    - Enum -> ``enum.value``
    - Path -> POSIX string
    - dict -> new dict; values are converted recursively, and keys that are
      an Enum or a Path are converted the same way
    - list/tuple -> new list with every item converted recursively

    Args:
        value: the config value (scalar or nested container) to convert.

    Returns:
        The converted value. Anything not listed above is returned unchanged.
    """
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, Path):
        return value.as_posix()
    if isinstance(value, dict):
        # Only Enum/Path keys are rewritten: other keys (e.g. tuples) must stay
        # as they are so they remain hashable.
        return {
            (_yaml_safe_value(key) if isinstance(key, (Enum, Path)) else key): _yaml_safe_value(item)
            for key, item in value.items()
        }
    if isinstance(value, (list, tuple)):
        # Tuples are emitted as plain lists.
        return [_yaml_safe_value(item) for item in value]
    return value


class DetectorMode(Enum):
    """Detector modes accepted for pose configs: native or external."""

    NATIVE = "native"
    EXTERNAL = "external"

    @classmethod
    def coerce_mode(
        cls,
        detector_mode: str | DetectorMode | None,
    ) -> DetectorMode | None:
        """Normalize a user-supplied detector mode to a ``DetectorMode`` member.

        Args:
            detector_mode: ``None``, an existing member, or a value whose string
                form (after stripping whitespace and lowercasing) names a mode.

        Returns:
            ``None`` when ``detector_mode`` is ``None``, otherwise the matching
            member.

        Raises:
            ValueError: if the value does not correspond to any mode.
        """
        # None and existing members pass through untouched.
        if detector_mode is None or isinstance(detector_mode, cls):
            return detector_mode

        key = str(detector_mode).strip().lower()
        members_by_value = {member.value: member for member in cls}
        if key in members_by_value:
            return members_by_value[key]
        raise ValueError(f"Unknown detector_mode: {detector_mode}")


def make_pytorch_pose_config(
project_config: dict,
pose_config_path: str | Path,
net_type: str | None = None,
top_down: bool = False,
detector_type: str | None = None,
detector_mode: Literal["native", "external"] | DetectorMode | None = None,
weight_init: WeightInitialization | None = None,
save: bool = False,
ctd_conditions: int | str | Path | tuple[int, str] | tuple[int, int] | None = None,
precomputed_bboxes: str | Path | None = None,
bbox_source: str | BBoxComputationMethod | None = None,
external_detector_metadata: dict | None = None,
) -> dict:
"""Creates a PyTorch pose configuration file for a DeepLabCut project.

Expand Down Expand Up @@ -65,8 +116,17 @@ def make_pytorch_pose_config(
by associating a detector to the pose model. Required for multi-animal
projects when net_type is a backbone (as a backbone + heatmap head can only
predict pose for single individuals).
detector_type: for top-down pose models, the architecture of the desired object
detector_type: for native top-down pose models, the architecture of the desired object
detection model
detector_mode:
Controls how top-down detector information is represented in the config.
- None: preserves legacy behavior
* if precomputed_bboxes is given -> external mode
* otherwise -> native detector mode
- "native": include a native DLC detector configuration
- "external": configure top-down pose training/inference to use external /
precomputed detector boxes instead of a native detector model.
If external, detector_type must be None and precomputed_bboxes must be provided.
weight_init: Specify how model weights should be initialized. If None, ImageNet
pretrained weights from Timm will be loaded when training.
save: Whether to save the model configuration file to the ``pose_config_path``.
Expand All @@ -79,7 +139,9 @@ def make_pytorch_pose_config(
predictions file.
* A shuffle number and a particular snapshot (ctd_conditions: tuple[int, str] | tuple[int, int]), which
respectively correspond to a bottom-up (BU) network type and a particular snapshot name or index.

precomputed_bboxes: str | Path, optional, default = None,
Path to a JSON artifact containing precomputed detector bounding boxes.
When provided with detector_mode=None, external detector mode is inferred.

Returns:
the PyTorch pose configuration file
Expand All @@ -90,14 +152,31 @@ def make_pytorch_pose_config(
bodyparts = auxiliaryfunctions.get_bodyparts(project_config)
unique_bpts = auxiliaryfunctions.get_unique_bodyparts(project_config)

if net_type is None:
net_type = project_config.get("default_net_type", "resnet_50")
if not net_type:
net_type = project_config.get("default_net_type")
if not net_type:
net_type = "resnet_50" # default backbone if net_type is not specified
logger.warning(f"No net_type specified in project config or as argument. Defaulting to {net_type}.")
if not isinstance(net_type, str):
raise TypeError(f"net_type must be a string, got {type(net_type)}")

configs_dir = get_config_folder_path()
pose_config = load_base_config(configs_dir)
pose_config = add_metadata(project_config, pose_config, pose_config_path)
pose_config["net_type"] = net_type

detector_mode = DetectorMode.coerce_mode(detector_mode)
if detector_mode is None:
if precomputed_bboxes is not None:
detector_mode = DetectorMode.EXTERNAL
else:
detector_mode = DetectorMode.NATIVE

if detector_mode == DetectorMode.EXTERNAL and not top_down and net_type in load_backbones(get_config_folder_path()):
raise ValueError(
"detector_mode='external' requires a top-down pose model. If using a backbone net_type, pass top_down=True."
)

backbones = load_backbones(configs_dir)
if net_type in backbones:
if not top_down and multianimal_project:
Expand Down Expand Up @@ -132,13 +211,47 @@ def make_pytorch_pose_config(
)

task = Task(model_cfg.get("method", "BU").upper())
if task == Task.TOP_DOWN:
model_cfg = add_detector(
configs_dir,
model_cfg,
len(individuals),
detector_type=detector_type,
if detector_mode == DetectorMode.EXTERNAL and task != Task.TOP_DOWN:
raise ValueError("detector_mode='external' can only be used with top-down pose models.")

if precomputed_bboxes is not None and task != Task.TOP_DOWN:
raise ValueError("precomputed_bboxes can only be used with top-down pose models.")
if detector_mode == DetectorMode.NATIVE and precomputed_bboxes is not None:
raise ValueError(
"precomputed_bboxes cannot be used with native detectors. If you want to use"
" precomputed boxes from an external detector, set detector_mode='external'."
)
if detector_mode == DetectorMode.EXTERNAL and detector_type is not None:
raise ValueError("detector_type cannot be used with detector_mode='external'.")
if (
task == Task.TOP_DOWN
and detector_mode == DetectorMode.NATIVE
and _yaml_safe_value(bbox_source) == BBoxComputationMethod.DETECTION_BBOX.value
and precomputed_bboxes is None
):
raise ValueError(
"bbox_source='detection_bbox' requires precomputed_bboxes when using "
"detector_mode='native'. If you want to train from external/offline detector "
"boxes, use detector_mode='external'."
)
if detector_mode != DetectorMode.EXTERNAL and external_detector_metadata is not None:
raise ValueError("external_detector_metadata can only be used with detector_mode='external'.")

if task == Task.TOP_DOWN:
if detector_mode == DetectorMode.NATIVE:
model_cfg = add_detector(
configs_dir,
model_cfg,
len(individuals),
detector_type=detector_type,
)
elif detector_mode == DetectorMode.EXTERNAL:
# Explicitly do NOT add a native detector model
model_cfg.setdefault("detector", {})
model_cfg["detector"].setdefault("train_settings", {})
model_cfg["detector"]["train_settings"]["epochs"] = 0
else:
raise ValueError(f"Unknown detector_mode: {detector_mode}")

# add the default augmentations to the config
aug_filename = "aug_default.yaml" if task == Task.BOTTOM_UP else "aug_top_down.yaml"
Expand All @@ -149,6 +262,39 @@ def make_pytorch_pose_config(
# add the model to the config
pose_config = update_config(pose_config, model_cfg)

# ------------------------------------------------------------------
# Configure bbox source / offline precomputed detector boxes
# ------------------------------------------------------------------
if "data" not in pose_config:
pose_config["data"] = {}

if detector_mode == DetectorMode.EXTERNAL and bbox_source is not None:
normalized_bbox_source = _yaml_safe_value(bbox_source)
if normalized_bbox_source != BBoxComputationMethod.DETECTION_BBOX.value:
raise ValueError("bbox_source must be 'detection_bbox' when detector_mode='external'.")

if detector_mode == DetectorMode.EXTERNAL:
if precomputed_bboxes is None:
raise ValueError("precomputed_bboxes is mandatory for external detector mode.")

pose_config["data"]["bbox_source"] = BBoxComputationMethod.DETECTION_BBOX.value
pose_config["data"]["precomputed_bboxes"] = Path(precomputed_bboxes).as_posix()

# Safe defaults for offline / precomputed detector matching
pose_config["data"].setdefault("bbox_match_iou_threshold", 0.1)
pose_config["data"].setdefault("bbox_fallback_to_gt", False)
pose_config["data"].setdefault("bbox_validate_image_paths", False)

elif bbox_source is not None:
pose_config["data"]["bbox_source"] = bbox_source

if detector_mode == DetectorMode.EXTERNAL:
pose_config.setdefault("metadata", {})
pose_config["metadata"]["detector"] = {
"mode": DetectorMode.EXTERNAL.value,
"info": _yaml_safe_value(external_detector_metadata or {}),
}

# set the dataset from which to load weights
if weight_init is not None:
pose_config["train_settings"]["weight_init"] = weight_init.to_dict()
Expand Down Expand Up @@ -194,6 +340,7 @@ def make_pytorch_pose_config(

# sort first-level keys to make it prettier
pose_config = dict(sorted(pose_config.items()))
pose_config = _yaml_safe_value(pose_config)

if save:
write_config(pose_config_path, pose_config, overwrite=True)
Expand Down
Loading
Loading