diff --git a/deeplabcut/pose_estimation_pytorch/apis/training.py b/deeplabcut/pose_estimation_pytorch/apis/training.py index dc636bc9f..eef6c7eb5 100644 --- a/deeplabcut/pose_estimation_pytorch/apis/training.py +++ b/deeplabcut/pose_estimation_pytorch/apis/training.py @@ -29,6 +29,9 @@ ) from deeplabcut.pose_estimation_pytorch.data.collate import COLLATE_FUNCTIONS from deeplabcut.pose_estimation_pytorch.models import DETECTORS, PoseModel +from deeplabcut.pose_estimation_pytorch.models.detectors.external.base import ( + build_precomputed_detector_runner_from_config, +) from deeplabcut.pose_estimation_pytorch.modelzoo.memory_replay import ( prepare_memory_replay, ) @@ -142,8 +145,56 @@ def train( logging.info(f" Training: {transform}") logging.info(f" Validation: {inference_transform}") - train_dataset = loader.create_dataset(transform=transform, mode="train", task=task) - valid_dataset = loader.create_dataset(transform=inference_transform, mode="test", task=task) + train_detector_runner = None + valid_detector_runner = None + + data_cfg = loader.model_cfg.get("data", {}) + bbox_source = data_cfg.get("bbox_source") + precomputed_bboxes = data_cfg.get("precomputed_bboxes") + + if task == Task.TOP_DOWN and bbox_source == "detection_bbox": + if not precomputed_bboxes: + raise ValueError( + "data.bbox_source='detection_bbox' was requested for top-down pose " + "training, but data.precomputed_bboxes is not configured. " + "Please provide a BBoxes JSON artifact or set data.bbox_source to " + "'gt' or 'keypoints'." + ) + + validate_image_paths = data_cfg.get("bbox_validate_image_paths", False) + + train_detector_runner = build_precomputed_detector_runner_from_config( + loader.model_cfg, + mode="train", + target_format="xywh", + validate_image_paths=validate_image_paths, + ) + valid_detector_runner = build_precomputed_detector_runner_from_config( + loader.model_cfg, + mode="test", + target_format="xywh", + validate_image_paths=validate_image_paths, + ) + elif precomputed_bboxes: + logging.info( + "data.precomputed_bboxes is configured but data.bbox_source=%r. " + "Ignoring precomputed detector boxes for this training run.", + bbox_source, + ) + + train_dataset = loader.create_dataset( + transform=transform, + mode="train", + task=task, + detector_runner=train_detector_runner, + ) + + valid_dataset = loader.create_dataset( + transform=inference_transform, + mode="test", + task=task, + detector_runner=valid_detector_runner, + ) collate_fn = None if collate_fn_cfg := run_config["data"]["train"].get("collate"): diff --git a/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py b/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py index 3cf316dd1..c47eb7c58 100644 --- a/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py +++ b/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py @@ -13,7 +13,10 @@ from __future__ import annotations import copy +import logging +from enum import Enum from pathlib import Path +from typing import Literal from deeplabcut.core.config import read_config_as_dict, write_config from deeplabcut.core.weight_init import WeightInitialization @@ -24,10 +27,54 @@ replace_default_values, update_config, ) +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod from deeplabcut.pose_estimation_pytorch.runners.inference import InferenceConfig from deeplabcut.pose_estimation_pytorch.task import Task from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions +logger = logging.getLogger(__name__) + + +def _yaml_safe_value(value): + """ + Convert config values to YAML-safe built-in Python types. + - Enum -> enum.value + - Path -> POSIX string + - dict/list/tuple -> recurse + """ + if isinstance(value, Enum): + return value.value + if isinstance(value, Path): + return value.as_posix() + if isinstance(value, dict): + return {k: _yaml_safe_value(v) for k, v in value.items()} + if isinstance(value, list): + return [_yaml_safe_value(v) for v in value] + if isinstance(value, tuple): + return [_yaml_safe_value(v) for v in value] + return value + + +class DetectorMode(Enum): + NATIVE = "native" + EXTERNAL = "external" + + @classmethod + def coerce_mode( + cls, + detector_mode: str | DetectorMode | None, + ) -> DetectorMode | None: + if detector_mode is None: + return None + if isinstance(detector_mode, cls): + return detector_mode + norm = str(detector_mode).strip().lower() + if norm == "native": + return cls.NATIVE + if norm == "external": + return cls.EXTERNAL + raise ValueError(f"Unknown detector_mode: {detector_mode}") + def make_pytorch_pose_config( project_config: dict, @@ -35,9 +82,13 @@ def make_pytorch_pose_config( net_type: str | None = None, top_down: bool = False, detector_type: str | None = None, + detector_mode: Literal["native", "external"] | DetectorMode | None = None, weight_init: WeightInitialization | None = None, save: bool = False, ctd_conditions: int | str | Path | tuple[int, str] | tuple[int, int] | None = None, + precomputed_bboxes: str | Path | None = None, + bbox_source: str | BBoxComputationMethod | None = None, + external_detector_metadata: dict | None = None, ) -> dict: """Creates a PyTorch pose configuration file for a DeepLabCut project. @@ -65,8 +116,17 @@ def make_pytorch_pose_config( by associating a detector to the pose model. Required for multi-animal projects when net_type is a backbone (as a backbone + heatmap head can only predict pose for single individuals). - detector_type: for top-down pose models, the architecture of the desired object + detector_type: for native top-down pose models, the architecture of the desired object detection model + detector_mode: + Controls how top-down detector information is represented in the config. + - None: preserves legacy behavior + * if precomputed_bboxes is given -> external mode + * otherwise -> native detector mode + - "native": include a native DLC detector configuration + - "external": configure top-down pose training/inference to use external / + precomputed detector boxes instead of a native detector model. + If external, detector_type must be None and precomputed_bboxes must be provided. weight_init: Specify how model weights should be initialized. If None, ImageNet pretrained weights from Timm will be loaded when training. save: Whether to save the model configuration file to the ``pose_config_path``. @@ -79,7 +139,9 @@ def make_pytorch_pose_config( predictions file. * A shuffle number and a particular snapshot (ctd_conditions: tuple[int, str] | tuple[int, int]), which respectively correspond to a bottom-up (BU) network type and a particular snapshot name or index. - + precomputed_bboxes: str | Path, optional, default = None, + Path to a JSON artifact containing precomputed detector bounding boxes. + When provided with detector_mode=None, external detector mode is inferred. Returns: the PyTorch pose configuration file @@ -90,14 +152,31 @@ def make_pytorch_pose_config( bodyparts = auxiliaryfunctions.get_bodyparts(project_config) unique_bpts = auxiliaryfunctions.get_unique_bodyparts(project_config) - if net_type is None: - net_type = project_config.get("default_net_type", "resnet_50") + if not net_type: + net_type = project_config.get("default_net_type") + if not net_type: + net_type = "resnet_50" # default backbone if net_type is not specified + logger.warning(f"No net_type specified in project config or as argument. Defaulting to {net_type}.") + if not isinstance(net_type, str): + raise TypeError(f"net_type must be a string, got {type(net_type)}") configs_dir = get_config_folder_path() pose_config = load_base_config(configs_dir) pose_config = add_metadata(project_config, pose_config, pose_config_path) pose_config["net_type"] = net_type + detector_mode = DetectorMode.coerce_mode(detector_mode) + if detector_mode is None: + if precomputed_bboxes is not None: + detector_mode = DetectorMode.EXTERNAL + else: + detector_mode = DetectorMode.NATIVE + + if detector_mode == DetectorMode.EXTERNAL and not top_down and net_type in load_backbones(get_config_folder_path()): + raise ValueError( + "detector_mode='external' requires a top-down pose model. If using a backbone net_type, pass top_down=True." + ) + backbones = load_backbones(configs_dir) if net_type in backbones: if not top_down and multianimal_project: @@ -132,13 +211,47 @@ def make_pytorch_pose_config( ) task = Task(model_cfg.get("method", "BU").upper()) - if task == Task.TOP_DOWN: - model_cfg = add_detector( - configs_dir, - model_cfg, - len(individuals), - detector_type=detector_type, + if detector_mode == DetectorMode.EXTERNAL and task != Task.TOP_DOWN: + raise ValueError("detector_mode='external' can only be used with top-down pose models.") + + if precomputed_bboxes is not None and task != Task.TOP_DOWN: + raise ValueError("precomputed_bboxes can only be used with top-down pose models.") + if detector_mode == DetectorMode.NATIVE and precomputed_bboxes is not None: + raise ValueError( + "precomputed_bboxes cannot be used with native detectors. If you want to use" + " precomputed boxes from an external detector, set detector_mode='external'." + ) + if detector_mode == DetectorMode.EXTERNAL and detector_type is not None: + raise ValueError("detector_type cannot be used with detector_mode='external'.") + if ( + task == Task.TOP_DOWN + and detector_mode == DetectorMode.NATIVE + and _yaml_safe_value(bbox_source) == BBoxComputationMethod.DETECTION_BBOX.value + and precomputed_bboxes is None + ): + raise ValueError( + "bbox_source='detection_bbox' requires precomputed_bboxes when using " + "detector_mode='native'. If you want to train from external/offline detector " + "boxes, use detector_mode='external'." ) + if detector_mode != DetectorMode.EXTERNAL and external_detector_metadata is not None: + raise ValueError("external_detector_metadata can only be used with detector_mode='external'.") + + if task == Task.TOP_DOWN: + if detector_mode == DetectorMode.NATIVE: + model_cfg = add_detector( + configs_dir, + model_cfg, + len(individuals), + detector_type=detector_type, + ) + elif detector_mode == DetectorMode.EXTERNAL: + # Explicitly do NOT add a native detector model + model_cfg.setdefault("detector", {}) + model_cfg["detector"].setdefault("train_settings", {}) + model_cfg["detector"]["train_settings"]["epochs"] = 0 + else: + raise ValueError(f"Unknown detector_mode: {detector_mode}") # add the default augmentations to the config aug_filename = "aug_default.yaml" if task == Task.BOTTOM_UP else "aug_top_down.yaml" @@ -149,6 +262,39 @@ def make_pytorch_pose_config( # add the model to the config pose_config = update_config(pose_config, model_cfg) + # ------------------------------------------------------------------ + # Configure bbox source / offline precomputed detector boxes + # ------------------------------------------------------------------ + if "data" not in pose_config: + pose_config["data"] = {} + + if detector_mode == DetectorMode.EXTERNAL and bbox_source is not None: + normalized_bbox_source = _yaml_safe_value(bbox_source) + if normalized_bbox_source != BBoxComputationMethod.DETECTION_BBOX.value: + raise ValueError("bbox_source must be 'detection_bbox' when detector_mode='external'.") + + if detector_mode == DetectorMode.EXTERNAL: + if precomputed_bboxes is None: + raise ValueError("precomputed_bboxes is mandatory for external detector mode.") + + pose_config["data"]["bbox_source"] = BBoxComputationMethod.DETECTION_BBOX.value + pose_config["data"]["precomputed_bboxes"] = Path(precomputed_bboxes).as_posix() + + # Safe defaults for offline / precomputed detector matching + pose_config["data"].setdefault("bbox_match_iou_threshold", 0.1) + pose_config["data"].setdefault("bbox_fallback_to_gt", False) + pose_config["data"].setdefault("bbox_validate_image_paths", False) + + elif bbox_source is not None: + pose_config["data"]["bbox_source"] = bbox_source + + if detector_mode == DetectorMode.EXTERNAL: + pose_config.setdefault("metadata", {}) + pose_config["metadata"]["detector"] = { + "mode": DetectorMode.EXTERNAL.value, + "info": _yaml_safe_value(external_detector_metadata or {}), + } + # set the dataset from which to load weights if weight_init is not None: pose_config["train_settings"]["weight_init"] = weight_init.to_dict() @@ -194,6 +340,7 @@ def make_pytorch_pose_config( # sort first-level keys to make it prettier pose_config = dict(sorted(pose_config.items())) + pose_config = _yaml_safe_value(pose_config) if save: write_config(pose_config_path, pose_config, overwrite=True) diff --git a/deeplabcut/pose_estimation_pytorch/data/base.py b/deeplabcut/pose_estimation_pytorch/data/base.py index 4795826ab..c88ebf059 100644 --- a/deeplabcut/pose_estimation_pytorch/data/base.py +++ b/deeplabcut/pose_estimation_pytorch/data/base.py @@ -10,14 +10,19 @@ # from __future__ import annotations +import copy +import logging from abc import ABC, abstractmethod from pathlib import Path +from typing import Protocol import albumentations as A import numpy as np +from scipy.optimize import linear_sum_assignment import deeplabcut.core.config as config_utils import deeplabcut.pose_estimation_pytorch.config as config +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod from deeplabcut.pose_estimation_pytorch.data.dataset import ( PoseDataset, PoseDatasetParameters, @@ -33,6 +38,23 @@ ) from deeplabcut.pose_estimation_pytorch.task import Task +logger = logging.getLogger(__name__) + + +class DetectorRunnerLike(Protocol): + """Minimal protocol for any detector inference runner used by the data layer.""" + + def inference( + self, + images, + shelf_writer=None, + ) -> list[dict[str, np.ndarray]]: + """ + Expected final postprocessed DLC output per image, e.g. + {"bboxes": np.ndarray[N, 4], "bbox_scores": np.ndarray[N]} + """ + ... + class Loader(ABC): """Abstract class that represents a blueprint for loading and processing dataset @@ -44,7 +66,7 @@ class Loader(ABC): create_dataset(images: dict = None, annotations: dict = None, transform: object = None, mode: str = "train", task: Task = Task.BOTTOM_UP) -> PoseDataset: Creates and returns a PoseDataset given a set of images, annotations, and other parameters. - _compute_bboxes(images, annotations, method: str = 'gt') -> dict: + _compute_bboxes(images, annotations, method: BBoxComputationMethod | str = BBoxComputationMethod.GT) -> dict: Retrieves all bounding boxes based on the specified method. get_dataset_parameters(*args, **kwargs) -> dict: Returns a dictionary containing dataset parameters derived from the configuration. @@ -125,6 +147,15 @@ def image_filenames(self, mode: str = "train") -> list[str]: data = self._loaded_data[mode] return [image["file_name"] for image in data["images"]] + def default_bbox_method(self, task: Task) -> BBoxComputationMethod | None: + """ + Returns the default bbox source for this loader/task. + Subclasses may override this to preserve legacy behavior. + """ + if task in (Task.TOP_DOWN, Task.DETECT): + return BBoxComputationMethod.GT + return None + def ground_truth_keypoints(self, mode: str = "train", unique_bodypart: bool = False) -> dict[str, np.ndarray]: """Creates a dictionary containing the ground truth data. @@ -230,23 +261,31 @@ def create_dataset( transform: A.BaseCompose | None = None, mode: str = "train", task: Task = Task.BOTTOM_UP, + detector_runner: DetectorRunnerLike | None = None, ) -> PoseDataset: - """Creates a PoseDataset based on provided arguments. + """Creates a PoseDataset based on provided arguments.""" - Args: - transform: Transformation to be applied on dataset. Defaults to None. - mode: Mode in which dataset is to be used (e.g., 'train', 'test'). Defaults to 'train'. - task: Task for which the dataset is being used. Defaults to 'BU'. - - Returns: - PoseDataset: An instance of the PoseDataset class. - - Raises: - Any exception raised by `get_dataset_parameters` or `load_data` methods. - """ parameters = self.get_dataset_parameters() data = self.load_data(mode) - data["annotations"] = self.filter_annotations(data["annotations"], task) + + # load_data() is cached -> never mutate cached annotations + images = data["images"] + annotations = copy.deepcopy(data["annotations"]) + + if task in (Task.TOP_DOWN, Task.DETECT): + bbox_method = self._resolve_bbox_method(task=task, detector_runner=detector_runner) + annotations = self._compute_bboxes( + images=images, + annotations=annotations, + method=bbox_method, + bbox_margin=self.model_cfg["data"].get("bbox_margin", 20), + detector_runner=detector_runner, + bbox_iou_threshold=self.model_cfg["data"].get("bbox_match_iou_threshold", 0.1), + fallback_to_gt=self.model_cfg["data"].get("bbox_fallback_to_gt", False), + ) + + annotations = self.filter_annotations(annotations, task) + ctd_config = None if self.pose_task == Task.COND_TOP_DOWN: ctd_config = GenSamplingConfig( @@ -255,8 +294,8 @@ def create_dataset( ) dataset = PoseDataset( - images=data["images"], - annotations=data["annotations"], + images=images, + annotations=annotations, transform=transform, mode=mode, task=task, @@ -300,12 +339,94 @@ def filter_annotations(annotations: list[dict], task: Task) -> list[dict]: return filtered_annotations + def _resolve_bbox_method( + self, + task: Task, + detector_runner: DetectorRunnerLike | None, + ) -> BBoxComputationMethod | None: + """ + Priority: + 1. detector_runner provided -> detector boxes + 2. explicit config bbox_source + 3. loader/task default + """ + if detector_runner is not None: + return BBoxComputationMethod.DETECTION_BBOX + + configured = self.model_cfg["data"].get("bbox_source") + if configured is not None: + return self._coerce_bbox_method(configured) + + default = self.default_bbox_method(task) + if default is not None: + return self._coerce_bbox_method(default) + + return None + + @staticmethod + def _coerce_bbox_method( + method: BBoxComputationMethod | str | None, + ) -> BBoxComputationMethod | None: + if method is None: + return None + if isinstance(method, BBoxComputationMethod): + return method + + normalized = method.strip().lower().replace(" ", "_") + aliases = { + "gt": BBoxComputationMethod.GT, + "keypoints": BBoxComputationMethod.KEYPOINTS, + "detection_bbox": BBoxComputationMethod.DETECTION_BBOX, + "detector": BBoxComputationMethod.DETECTION_BBOX, + "segmentation_mask": BBoxComputationMethod.SEGMENTATION_MASK, + } + try: + return aliases[normalized] + except KeyError as e: + raise ValueError(f"Invalid bbox computation method: {method}") from e + + @staticmethod + def _get_reference_bbox_for_matching( + annotation: dict, + image_h: int, + image_w: int, + bbox_margin: int, + ) -> np.ndarray: + """ + Returns the reference bbox to use when matching detector predictions to annotations. + + Priority: + 1. derive bbox from keypoints if possible + 2. fall back to annotation["bbox"] if present + 3. raise if neither is available + """ + keypoints = annotation.get("keypoints") + if keypoints is not None: + keypoints = np.asarray(keypoints, dtype=np.float32) + if keypoints.size > 0: + visible = keypoints[..., 2] > 0 + if np.any(visible): + return bbox_from_keypoints( + keypoints=keypoints, + image_h=image_h, + image_w=image_w, + margin=bbox_margin, + ).astype(np.float32) + + if "bbox" in annotation: + return np.asarray(annotation["bbox"], dtype=np.float32) + + raise ValueError("Cannot build reference bbox for matching: annotation has neither visible keypoints nor bbox.") + @staticmethod def _compute_bboxes( images: list[dict], annotations: list[dict], - method: str = "gt", + method: BBoxComputationMethod | str = BBoxComputationMethod.GT, bbox_margin: int = 20, + detector_runner: DetectorRunnerLike | None = None, + bbox_iou_threshold: float = 0.1, + fallback_to_gt: bool = False, ): """TODO: Nastya method of bbox computation (detection bbox, seg. mask, ...) Retrieves all bounding boxes based on the given method. @@ -328,24 +449,25 @@ def _compute_bboxes( ValueError: If method is not one of 'gt', 'detection bbox', 'keypoints', or 'segmentation mask'. """ - if not method: + method = Loader._coerce_bbox_method(method) + if method is None: return annotations - elif method == "gt": - for _i, annotation in enumerate(annotations): + if fallback_to_gt and method != BBoxComputationMethod.DETECTION_BBOX: + logger.warning( + "bbox_fallback_to_gt is only applicable when method='detection bbox'. Ignoring fallback_to_gt." + ) + if method == BBoxComputationMethod.GT: + for annotation in annotations: if "bbox" not in annotation: - # or do something else? raise ValueError( - f"Bounding box not found in annotation {annotation}, please " - "chose another bbox computation method" + f"Bounding box not found in annotation {annotation}, " + "please choose another bbox computation method" ) return annotations - elif method == "detection bbox": - raise NotImplementedError - - elif method == "keypoints": - min_area = 1 # TODO: should not be hardcoded + elif method == BBoxComputationMethod.KEYPOINTS: + min_area = 1 img_id_to_annotations = map_id_to_annotations(annotations) for img in images: anns = [annotations[idx] for idx in img_id_to_annotations[img["id"]]] @@ -359,8 +481,183 @@ def _compute_bboxes( a["area"] = max(min_area, (a["bbox"][2] * a["bbox"][3]).item()) return annotations - elif method == "segmentation mask": + elif method == BBoxComputationMethod.DETECTION_BBOX: + if detector_runner is None: + raise ValueError("detector_runner must be provided when method='detection bbox'") + + img_id_to_annotations = map_id_to_annotations(annotations) + image_inputs = [img["file_name"] for img in images] + predictions = detector_runner.inference(image_inputs) + + if len(predictions) != len(images): + raise ValueError(f"Detector returned {len(predictions)} predictions for {len(images)} images") + + num_unmatched = 0 + num_total = 0 + + for img, pred in zip(images, predictions, strict=False): + ann_indices = img_id_to_annotations[img["id"]] + + # Only match real individuals, not unique-bodypart-only annotations + candidate_ann_indices = [idx for idx in ann_indices if annotations[idx].get("category_id", 1) == 1] + + if len(candidate_ann_indices) == 0: + continue + + pred_bboxes = np.asarray(pred.get("bboxes", np.zeros((0, 4))), dtype=np.float32).reshape(-1, 4) + pred_scores = np.asarray( + pred.get("bbox_scores", np.ones((len(pred_bboxes),), dtype=np.float32)), + dtype=np.float32, + ).reshape(-1) + + gt_bboxes = np.stack( + [ + Loader._get_reference_bbox_for_matching( + annotation=annotations[idx], + image_h=img["height"], + image_w=img["width"], + bbox_margin=bbox_margin, + ) + for idx in candidate_ann_indices + ], + axis=0, + ) + + # Simple / common case: one candidate annotation only (single animal). + # In this case, trust the detector and assign directly the highest scoring bbox rather than trying + # to IoU-match against a potentially stale placeholder bbox. + if len(candidate_ann_indices) == 1 and len(pred_bboxes) > 0: + ann_idx = candidate_ann_indices[0] + if len(pred_scores) == len(pred_bboxes): + pred_idx = int(np.argmax(pred_scores)) + else: + pred_idx = 0 + matched_bbox = pred_bboxes[pred_idx].astype(np.float32, copy=True) + annotations[ann_idx]["bbox"] = matched_bbox + annotations[ann_idx]["area"] = max(1.0, float(matched_bbox[2] * matched_bbox[3])) + num_total += 1 + continue + + matches = Loader._match_bboxes_iou( + gt_bboxes=gt_bboxes, + pred_bboxes=pred_bboxes, + pred_scores=pred_scores, + iou_threshold=bbox_iou_threshold, + ) + + num_total += len(candidate_ann_indices) + + for local_gt_idx, ann_idx in enumerate(candidate_ann_indices): + pred_idx = matches.get(local_gt_idx, None) + + if pred_idx is None: + num_unmatched += 1 + if not fallback_to_gt: + annotations[ann_idx]["bbox"] = np.zeros((4,), dtype=np.float32) + annotations[ann_idx]["area"] = 0.0 + continue + + matched_bbox = pred_bboxes[pred_idx].astype(np.float32, copy=True) + annotations[ann_idx]["bbox"] = matched_bbox + annotations[ann_idx]["area"] = max(1.0, float(matched_bbox[2] * matched_bbox[3])) + + if num_total > 0 and num_unmatched > 0: + logging.info( + f"Detector bbox matching: {num_total - num_unmatched}/{num_total} annotations matched " + f"(fallback_to_gt={fallback_to_gt})" + ) + if not fallback_to_gt: + logging.error( + f"{num_unmatched} annotations were not matched to any detection bbox " + "and were assigned empty bounding boxes. " + "Please review the detector performance!" + "If this is expected and/or gt fallback is reasonable in your case, " + "consider setting bbox_fallback_to_gt=True in the config to use gt bboxes " + "for unmatched annotations instead of empty bboxes." + ) + + return annotations + + if method == BBoxComputationMethod.SEGMENTATION_MASK: raise NotImplementedError - else: - raise ValueError(f"Unknown method: {method}") + raise ValueError(f"Unknown method: {method}") + + @staticmethod + def _xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray: + """Convert boxes from xywh -> xyxy.""" + boxes = np.asarray(boxes, dtype=np.float32).copy() + if boxes.size == 0: + return boxes.reshape(0, 4) + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + return boxes + + @staticmethod + def _bbox_iou_xywh(boxes_a: np.ndarray, boxes_b: np.ndarray) -> np.ndarray: + """ + Compute pairwise IoU between two sets of boxes in xywh format. + Returns matrix of shape [len(boxes_a), len(boxes_b)]. + """ + boxes_a = Loader._xywh_to_xyxy(boxes_a) + boxes_b = Loader._xywh_to_xyxy(boxes_b) + + if len(boxes_a) == 0 or len(boxes_b) == 0: + return np.zeros((len(boxes_a), len(boxes_b)), dtype=np.float32) + + ious = np.zeros((len(boxes_a), len(boxes_b)), dtype=np.float32) + for i, a in enumerate(boxes_a): + ax1, ay1, ax2, ay2 = a + a_area = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1) + + for j, b in enumerate(boxes_b): + bx1, by1, bx2, by2 = b + b_area = max(0.0, bx2 - bx1) * max(0.0, by2 - by1) + + ix1 = max(ax1, bx1) + iy1 = max(ay1, by1) + ix2 = min(ax2, bx2) + iy2 = min(ay2, by2) + + iw = max(0.0, ix2 - ix1) + ih = max(0.0, iy2 - iy1) + inter = iw * ih + + union = a_area + b_area - inter + if union > 0: + ious[i, j] = inter / union + + return ious + + @staticmethod + def _match_bboxes_iou( + gt_bboxes: np.ndarray, + pred_bboxes: np.ndarray, + pred_scores: np.ndarray | None = None, + iou_threshold: float = 0.1, + ) -> dict[int, int]: + """ + Match predicted boxes to GT boxes using Hungarian assignment on IoU cost. + + Returns: + dict mapping local_gt_index -> pred_index + """ + if len(gt_bboxes) == 0 or len(pred_bboxes) == 0: + return {} + + iou = Loader._bbox_iou_xywh(gt_bboxes, pred_bboxes) + + # Prefer higher score very slightly when IoUs are tied + cost = 1.0 - iou + if pred_scores is not None and len(pred_scores) == pred_bboxes.shape[0]: + score_penalty = (1.0 - pred_scores.reshape(1, -1)) * 1e-6 + cost = cost + score_penalty + + gt_idx, pred_idx = linear_sum_assignment(cost) + + matches: dict[int, int] = {} + for g, p in zip(gt_idx, pred_idx, strict=False): + if iou[g, p] >= iou_threshold: + matches[int(g)] = int(p) + + return matches diff --git a/deeplabcut/pose_estimation_pytorch/data/bboxes.py b/deeplabcut/pose_estimation_pytorch/data/bboxes.py new file mode 100644 index 000000000..f9052d00e --- /dev/null +++ b/deeplabcut/pose_estimation_pytorch/data/bboxes.py @@ -0,0 +1,254 @@ +from __future__ import annotations + +from enum import Enum +from pathlib import Path +from typing import Any, Literal, TypeAlias, TypedDict + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +# ----------------------------------------------------------------------------- +# Types +# ----------------------------------------------------------------------------- + +BBoxFormat = Literal["xywh", "xyxy", "cxcywh"] +EvalMode: TypeAlias = Literal["train", "test"] + + +class BBoxComputationMethod(str, Enum): + GT = "gt" + KEYPOINTS = "keypoints" + DETECTION_BBOX = "detection_bbox" + SEGMENTATION_MASK = "segmentation_mask" + + +class DetectorContext(TypedDict): + bboxes: np.ndarray + bbox_scores: np.ndarray + + +ImageWithContext: TypeAlias = tuple[Path, DetectorContext] +ImagesWithContext: TypeAlias = list[ImageWithContext] + + +# ----------------------------------------------------------------------------- +# Helpers +# ----------------------------------------------------------------------------- + + +def _numpy_to_jsonable(obj: Any) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, dict): + return {k: _numpy_to_jsonable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_numpy_to_jsonable(x) for x in obj] + return obj + + +def _xyxy_to_xywh(boxes: np.ndarray) -> np.ndarray: + """Assumes top-left origin. Converts [x_min, y_min, x_max, y_max] to [x_min, y_min, width, height].""" + boxes = np.asarray(boxes, dtype=np.float32).copy().reshape(-1, 4) + if len(boxes) == 0: + return boxes + boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + return boxes + + +def _xywh_to_xyxy(boxes: np.ndarray) -> np.ndarray: + """Assumes top-left origin. Converts [x_min, y_min, width, height] to [x_min, y_min, x_max, y_max].""" + boxes = np.asarray(boxes, dtype=np.float32).copy().reshape(-1, 4) + if len(boxes) == 0: + return boxes + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + return boxes + + +def _cxcywh_to_xyxy(boxes): + """Converts [center_x, center_y, width, height] to [x_min, y_min, x_max, y_max].""" + boxes = np.asarray(boxes, dtype=np.float32).copy().reshape(-1, 4) + if len(boxes) == 0: + return boxes + x, y, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3] + boxes[:, 0] = x - w / 2 + boxes[:, 1] = y - h / 2 + boxes[:, 2] = x + w / 2 + boxes[:, 3] = y + h / 2 + return boxes + + +# ----------------------------------------------------------------------------- +# Base model +# ----------------------------------------------------------------------------- + + +class StrictBaseModel(BaseModel): + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + +# ----------------------------------------------------------------------------- +# BBox schemas +# ----------------------------------------------------------------------------- + + +class BBoxEntry(StrictBaseModel): + """ + Bounding box output for one image. + + `bboxes` are stored in pixel coordinates, with format declared by `bbox_format`. + `bbox_scores` is aligned one-to-one with `bboxes`. + """ + + bboxes: list[tuple[float, float, float, float]] + bbox_scores: list[float] + bbox_format: BBoxFormat = "xyxy" + image_path: Path | None = None + + @classmethod + def from_detector_context( + cls, + context: DetectorContext, + *, + image_path: Path | None = None, + bbox_format: BBoxFormat = "xywh", + ) -> BBoxEntry: + """ + Build a schema entry from DLC-style detector context. + + Args: + context: + Expected format: + { + "bboxes": np.ndarray[N, 4], + "bbox_scores": np.ndarray[N] + } + image_path: + Optional path of the corresponding image. + bbox_format: + Format of `context["bboxes"]`. + Use: + - "xywh" for DLC postprocessed detector outputs / top-down context + - "xyxy" if adapting raw detector outputs before DLC postprocessing + + Returns: + BBoxEntry + """ + if "bboxes" not in context: + raise ValueError("Detector context must contain 'bboxes'.") + + bboxes = np.asarray(context["bboxes"], dtype=np.float32).reshape(-1, 4) + + if "bbox_scores" in context: + scores = np.asarray(context["bbox_scores"], dtype=np.float32).reshape(-1) + else: + # Allow score-less contexts, but fill with 1.0 + scores = np.ones((len(bboxes),), dtype=np.float32) + + if len(scores) != len(bboxes): + raise ValueError(f"Expected one bbox score per bbox, but got {len(scores)} scores for {len(bboxes)} boxes.") + + return cls( + bboxes=[tuple(map(float, box)) for box in bboxes], + bbox_scores=[float(s) for s in scores], + bbox_format=bbox_format, + image_path=image_path, + ) + + def to_array(self, *, dtype: np.dtype[Any] = np.float32) -> np.ndarray: + """Return bboxes as a NumPy array of shape [N, 4].""" + return np.asarray(self.bboxes, dtype=dtype).reshape(-1, 4) + + def to_xywh(self, *, dtype: np.dtype[Any] = np.float32) -> np.ndarray: + """Return bboxes in xywh format.""" + boxes = self.to_array(dtype=dtype) + if self.bbox_format == "xyxy": + boxes = _xyxy_to_xywh(boxes) + elif self.bbox_format == "cxcywh": + boxes = _xyxy_to_xywh(_cxcywh_to_xyxy(boxes)) + return boxes + + def to_xyxy(self, *, dtype: np.dtype[Any] = np.float32) -> np.ndarray: + """Return bboxes in xyxy format.""" + boxes = self.to_array(dtype=dtype) + if self.bbox_format == "xywh": + boxes = _xywh_to_xyxy(boxes) + elif self.bbox_format == "cxcywh": + boxes = _cxcywh_to_xyxy(boxes) + return boxes + + def to_detector_context( + self, + *, + dtype: np.dtype[Any] = np.float32, + target_format: BBoxFormat = "xywh", + ) -> DetectorContext: + """ + Convert this entry to DLC detector context format. + + Args: + dtype: + NumPy dtype for emitted arrays. + target_format: + Desired bbox format in the returned context. + For most DLC top-down dataset / pose use, this should be "xywh". + + Returns: + { + "bboxes": np.ndarray[N, 4], + "bbox_scores": np.ndarray[N], + } + """ + if target_format == "xywh": + bboxes = self.to_xywh(dtype=dtype) + else: + bboxes = self.to_xyxy(dtype=dtype) + + return { + "bboxes": bboxes, + "bbox_scores": np.asarray(self.bbox_scores, dtype=dtype), + } + + +class BBoxes(StrictBaseModel): + train: list[BBoxEntry] = Field(default_factory=list) + test: list[BBoxEntry] = Field(default_factory=list) + + @classmethod + def from_file(cls, json_file: Path, missing_ok: bool = False) -> BBoxes: + if not json_file.exists(): + if missing_ok: + return cls() + raise FileNotFoundError(f"BBoxes file not found: {json_file}") + return cls.from_json(json_file.read_text(encoding="utf-8")) + + @classmethod + def from_json(cls, json_str: str) -> BBoxes: + return cls.model_validate_json(json_str) + + def dump_json(self, json_file: Path) -> None: + Path(json_file).parent.mkdir(parents=True, exist_ok=True) + json_file.write_text(self.model_dump_json(indent=4), encoding="utf-8") + + def to_images_with_context( + self, + image_paths: list[Path], + mode: EvalMode, + *, + target_format: BBoxFormat = "xywh", + ) -> ImagesWithContext: + """ + Zip image paths with detector context in DLC expected format. + """ + mode_bboxes = getattr(self, mode) + if len(image_paths) != len(mode_bboxes): + raise ValueError(f"Got {len(image_paths)} {mode} images but {len(mode_bboxes)} bbox entries.") + + return [ + ( + image_path, + bbox_entry.to_detector_context(target_format=target_format), + ) + for image_path, bbox_entry in zip(image_paths, mode_bboxes, strict=False) + ] diff --git a/deeplabcut/pose_estimation_pytorch/data/cocoloader.py b/deeplabcut/pose_estimation_pytorch/data/cocoloader.py index 7fdea47f4..15d724a4f 100644 --- a/deeplabcut/pose_estimation_pytorch/data/cocoloader.py +++ b/deeplabcut/pose_estimation_pytorch/data/cocoloader.py @@ -18,6 +18,7 @@ import numpy as np from deeplabcut.pose_estimation_pytorch.data.base import Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters from deeplabcut.pose_estimation_pytorch.data.utils import ( map_id_to_annotations, @@ -274,7 +275,7 @@ def load_data(self, mode: str = "train") -> dict: annotations_with_bbox = self._compute_bboxes( data["images"], data["annotations"], - method="gt", + method=BBoxComputationMethod.GT, ) data["annotations"] = annotations_with_bbox return data diff --git a/deeplabcut/pose_estimation_pytorch/data/dlcloader.py b/deeplabcut/pose_estimation_pytorch/data/dlcloader.py index c034f0e74..6034c591f 100644 --- a/deeplabcut/pose_estimation_pytorch/data/dlcloader.py +++ b/deeplabcut/pose_estimation_pytorch/data/dlcloader.py @@ -25,9 +25,11 @@ from deeplabcut.core.engine import Engine from deeplabcut.generate_training_dataset.trainingsetmanipulation import drop_likelihood_columns from deeplabcut.pose_estimation_pytorch.data.base import Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters from deeplabcut.pose_estimation_pytorch.data.snapshots import Snapshot from deeplabcut.pose_estimation_pytorch.data.utils import bbox_from_keypoints, read_image_shape_fast +from deeplabcut.pose_estimation_pytorch.task import Task class DLCLoader(Loader): @@ -174,6 +176,15 @@ def get_dataset_parameters(self) -> PoseDatasetParameters: top_down_crop_with_context=crop_with_context, ) + def default_bbox_method(self, task: Task) -> BBoxComputationMethod | None: + """ + Preserve historical DLCLoader behavior: + for detector and top-down tasks, derive boxes from keypoints unless explicitly overridden. + """ + if task in (Task.TOP_DOWN, Task.DETECT): + return BBoxComputationMethod.KEYPOINTS + return None + def load_data(self, mode: str = "train") -> dict: """Loads DeepLabCut data into COCO-style annotations. @@ -197,14 +208,14 @@ def load_data(self, mode: str = "train") -> dict: raise ValueError(f"No data in {mode} split for this shuffle!") params = self.get_dataset_parameters() - data = self.to_coco(str(self._project_root), self._dfs[mode], params) - with_bbox = self._compute_bboxes( - data["images"], - data["annotations"], - method="keypoints", - bbox_margin=self.model_cfg["data"].get("bbox_margin", 20), - ) - data["annotations"] = with_bbox + bbox_margin = self.model_cfg["data"].get("bbox_margin", 20) + data = self.to_coco(str(self._project_root), self._dfs[mode], params, bbox_margin=bbox_margin) + + # `to_coco(...)` initializes keypoint-derived GT bboxes for compatibility + # with APIs that consume `load_data()` directly. The margin is config-driven. + # + # `create_dataset(...)` still owns the effective training bbox source and may + # rewrite these boxes according to bbox_source / detector_runner. return data def load_ground_truth( @@ -363,6 +374,7 @@ def to_coco( project_root: str | Path, df: pd.DataFrame, parameters: PoseDatasetParameters, + bbox_margin: int = 20, ) -> dict: """Formerly Shaokai's function. @@ -370,6 +382,7 @@ def to_coco( project_root: the path to the project root df: the DLC-format annotation dataframe to convert to a COCO-format dict parameters: the parameters for pose estimation + bbox_margin: the margin to add around the bounding boxes Returns: the coco format data @@ -464,12 +477,12 @@ def to_coco( ) coco_dict = {"annotations": anns, "categories": categories, "images": images} - coco_dict = DLCLoader._add_bbox_annotations(coco_dict) + coco_dict = DLCLoader._add_bbox_annotations(coco_dict, bbox_margin=bbox_margin) coco_dict = DLCLoader._remove_nans(coco_dict) return coco_dict @staticmethod - def _add_bbox_annotations(coco_dict: dict) -> dict: + def _add_bbox_annotations(coco_dict: dict, bbox_margin: int = 20) -> dict: for annotation in coco_dict.get("annotations", []): if "bbox" not in annotation: image = [img for img in coco_dict.get("images") if img.get("id") == annotation.get("image_id")][0] @@ -477,7 +490,7 @@ def _add_bbox_annotations(coco_dict: dict) -> dict: keypoints=np.array(annotation["keypoints"]), # (..., num_keypoints, xy) image_h=image.get("height"), image_w=image.get("width"), - margin=20, + margin=bbox_margin, ) annotation["bbox"] = list(bbox) return coco_dict diff --git a/deeplabcut/pose_estimation_pytorch/data/utils.py b/deeplabcut/pose_estimation_pytorch/data/utils.py index d4f6797a2..37372cd3d 100644 --- a/deeplabcut/pose_estimation_pytorch/data/utils.py +++ b/deeplabcut/pose_estimation_pytorch/data/utils.py @@ -276,27 +276,25 @@ def _compute_crop_bounds( def _extract_keypoints_and_bboxes( - anns: list[dict], - image_shape: tuple[int, int, int], - num_joints: int, - num_unique_bodyparts: int, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray]]: + anns, + image_shape, + num_joints, + num_unique_bodyparts, +): """ - Args: - anns: COCO-style annotations - image_shape: the (h, w, c) shape of the image for which to get annotations - num_joints: the number of joints in the annotations - - Returns: - keypoints, unique_keypoints, bboxes in xywh format, annotations_merged + Patch for DLC training when annotations are missing COCO metadata like: + area, category_id, iscrowd, individual_id. """ keypoints = [] original_bboxes = [] anns_to_merge = [] unique_keypoints = None + h, w = image_shape[:2] - for _i, annotation in enumerate(anns): + + for i, annotation in enumerate(anns): keypoints_individual = _annotation_to_keypoints(annotation, h, w) + if annotation["individual"] != "single": bbox_individual = annotation["bbox"] original_bboxes.append(bbox_individual) @@ -312,19 +310,40 @@ def _extract_keypoints_and_bboxes( original_bboxes = safe_stack(original_bboxes, (0, 4)) bboxes = _compute_crop_bounds(original_bboxes, image_shape, remove_empty=False) - # at least 1 visible joint to keep individuals + # Keep only individuals with at least one visible joint vis_mask = (keypoints[..., 2] > 0).any(axis=1) keypoints = keypoints[vis_mask] bboxes = bboxes[vis_mask] - keys_to_merge = ["area", "category_id", "iscrowd", "individual_id"] - anns_merged = {k: [] for k in keys_to_merge} - if len(anns_to_merge) > 0: - anns_merged = merge_list_of_dicts(anns_to_merge, keys_to_include=keys_to_merge) - anns_merged = {k: np.array(v)[vis_mask] for k, v in anns_merged.items()} - - if len(anns_merged["area"]) != len(keypoints): - raise ValueError(f"Missing area values! {anns_merged}, {keypoints.shape}") + def default_area(annotation): + if "area" in annotation: + return float(annotation["area"]) + + if "bbox" in annotation and len(annotation["bbox"]) == 4: + # bbox is assumed xywh + return float(annotation["bbox"][2]) * float(annotation["bbox"][3]) + + # fallback from visible keypoints + kp = np.asarray(annotation["keypoints"], dtype=float).reshape(-1, 3) + visible = kp[kp[:, 2] > 0, :2] + if len(visible) == 0: + return 0.0 + mins = visible.min(axis=0) + maxs = visible.max(axis=0) + wh = np.maximum(maxs - mins, 1.0) + return float(wh[0] * wh[1]) + + area = np.array([default_area(a) for a in anns_to_merge], dtype=float) + category_id = np.array([a.get("category_id", 1) for a in anns_to_merge], dtype=int) + iscrowd = np.array([a.get("iscrowd", 0) for a in anns_to_merge], dtype=int) + individual_id = np.array([a.get("individual_id", i) for i, a in enumerate(anns_to_merge)], dtype=int) + + anns_merged = { + "area": area[vis_mask], + "category_id": category_id[vis_mask], + "iscrowd": iscrowd[vis_mask], + "individual_id": individual_id[vis_mask], + } return keypoints, unique_keypoints, bboxes, anns_merged diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/external/__init__.py b/deeplabcut/pose_estimation_pytorch/models/detectors/external/__init__.py new file mode 100644 index 000000000..1a3f5a15d --- /dev/null +++ b/deeplabcut/pose_estimation_pytorch/models/detectors/external/__init__.py @@ -0,0 +1,11 @@ +from .base import EXTERNAL_DETECTORS, BaseExternalDetector, DetectionResult, PrecomputedDetectorRunner + +# Import all external detectors here to populate the registry +from .mock import MockExternalDetector + +__all__ = [ + "BaseExternalDetector", + "EXTERNAL_DETECTORS", + "DetectionResult", + "PrecomputedDetectorRunner", +] diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/external/base.py b/deeplabcut/pose_estimation_pytorch/models/detectors/external/base.py new file mode 100644 index 000000000..488807aaa --- /dev/null +++ b/deeplabcut/pose_estimation_pytorch/models/detectors/external/base.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Protocol, TypedDict + +import torch +import torch.nn as nn + +from deeplabcut.pose_estimation_pytorch.data.base import DetectorRunnerLike, Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxEntry, BBoxes, BBoxFormat, DetectorContext, EvalMode +from deeplabcut.pose_estimation_pytorch.registry import Registry, build_from_cfg + + +class DetectionResult(TypedDict, total=False): + boxes: torch.Tensor # FloatTensor [N, 4], absolute XYXY pixel coords + scores: torch.Tensor # FloatTensor [N] + labels: torch.Tensor # LongTensor [N] + # Optional future extensions: + # masks: torch.Tensor + # embeddings: torch.Tensor + # class_names: list[str] + + +class DetectorForwardLike(Protocol): + def forward( + self, + x: torch.Tensor | list[torch.Tensor], + targets: list[dict[str, torch.Tensor]] | None = None, + ) -> tuple[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]: ... + + +def _build_external_detector( + cfg: dict, + **kwargs, +) -> BaseExternalDetector: + """ + Builds an external detector from config. + + Unlike native DLC detectors, external detectors are assumed to be + inference-oriented and usually are not trained (but the pose estimation model may be trained on top of them). + As such, external detectors are not expected to have a training loop, and may not even have an optimizer or + snapshot loading, or target generation. + """ + detector: BaseExternalDetector = build_from_cfg(cfg, **kwargs) + return detector + + +EXTERNAL_DETECTORS = Registry("external_detectors", build_func=_build_external_detector) + + +class BaseExternalDetector(ABC, nn.Module): + """ + Base class for external / frozen detectors. + + These detectors expose a canonical inference API: + predict(images) -> list[DetectionResult] + + and a forward() shim for compatibility with DLC inference runners: + forward(images, targets=None) -> ({}, detections) + """ + + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def predict( + self, + images: list[torch.Tensor], + ) -> list[DetectionResult]: + """ + Run detection on a batch of images. + + Args: + images: + List of images, each typically a tensor of shape [C, H, W]. + + Returns: + One detection dict per image: + { + "boxes": FloatTensor[N, 4], # XYXY absolute pixel coords + "scores": FloatTensor[N], + "labels": LongTensor[N], + } + """ + raise NotImplementedError + + def forward( + self, + x: torch.Tensor | list[torch.Tensor], + targets: list[dict[str, torch.Tensor]] | None = None, + ) -> tuple[dict[str, torch.Tensor], list[DetectionResult]]: + """ + Compatibility shim so external detectors can be used with existing + inference-runner code that expects nn.Module.forward(). + + For inference-only external detectors, losses are always empty. + """ + if isinstance(x, torch.Tensor): + # Assume batched BCHW tensor -> convert to list[CHW] + images = list(x) + else: + images = x + + detections = self.predict(images) + return {}, detections + + +class PrecomputedDetectorRunner: + """ + Adapter that makes precomputed bbox entries behave like a detector runner. + + This is useful when you want to: + - train a top-down pose model using precomputed detector outputs + - run pose inference with saved bounding boxes + - avoid running a live detector at all + + It implements the minimal `inference(images, shelf_writer=None)` method expected + by the loader / dataset creation pathway. + """ + + def __init__( + self, + entries: list[BBoxEntry], + *, + target_format: BBoxFormat = "xywh", + validate_image_paths: bool = False, + ) -> None: + self.entries = list(entries) + self.target_format = target_format + self.validate_image_paths = validate_image_paths + + self._entries_by_path: dict[str, BBoxEntry] = {} + for entry in self.entries: + if entry.image_path is None: + continue + + key = self._normalize_path_for_compare(entry.image_path) + if key in self._entries_by_path: + raise ValueError(f"Duplicate precomputed bbox entry for image_path={entry.image_path}") + self._entries_by_path[key] = entry + + @staticmethod + def _normalize_path_for_compare(path: Path | str) -> str: + return Path(path).as_posix() + + @classmethod + def from_bboxes( + cls, + bboxes: BBoxes, + mode: EvalMode, + *, + target_format: BBoxFormat = "xywh", + validate_image_paths: bool = False, + ) -> PrecomputedDetectorRunner: + return cls( + entries=getattr(bboxes, mode), + target_format=target_format, + validate_image_paths=validate_image_paths, + ) + + @staticmethod + def _normalize_path_for_compare(path: Path | str) -> str: + return Path(path).as_posix().lower() + + @staticmethod + def _extract_image_path(item) -> Path | None: + if isinstance(item, tuple): + image = item[0] + else: + image = item + + if isinstance(image, (str, Path)): + return Path(image) + + return None + + def _find_entry_by_suffix(self, requested_path: Path) -> BBoxEntry | None: + requested = self._normalize_path_for_compare(requested_path) + + matches = [] + for entry in self.entries: + if entry.image_path is None: + continue + + entry_path = self._normalize_path_for_compare(entry.image_path) + + if requested.endswith(entry_path) or entry_path.endswith(requested): + matches.append(entry) + + if len(matches) == 1: + return matches[0] + + if len(matches) > 1: + raise ValueError( + f"Ambiguous precomputed bbox entries for requested image {requested_path}: " + f"{[m.image_path for m in matches]}" + ) + + return None + + def inference(self, images, shelf_writer=None) -> list[DetectorContext]: + """ + Return precomputed detector outputs aligned with the requested images. + + Args: + images: + Iterable of image inputs passed through DLC. + Supported elements: + - Path / str + - (Path / str, context_dict) + - np.ndarray / other non-path objects (order-only matching) + + shelf_writer: + Accepted for compatibility, ignored. + + Returns: + List of DLC detector contexts: + [{"bboxes": ..., "bbox_scores": ...}, ...] + """ + images = list(images) + requested_paths = [self._extract_image_path(item) for item in images] + + outputs: list[DetectorContext] = [] + + can_path_match = len(self._entries_by_path) > 0 and all(path is not None for path in requested_paths) + + if can_path_match: + for requested_path in requested_paths: + assert requested_path is not None + key = self._normalize_path_for_compare(requested_path) + + entry = self._entries_by_path.get(key) + + if entry is None: + # Optional useful fallback: match by filename/suffix when exact path differs. + entry = self._find_entry_by_suffix(requested_path) + + if entry is None: + raise ValueError( + f"No precomputed bbox entry found for requested image {requested_path}. " + f"Known entries include: {list(self._entries_by_path.keys())[:5]}" + ) + + outputs.append(entry.to_detector_context(target_format=self.target_format)) + + return outputs + + # Order-only fallback. + # This is necessary for ndarray inputs or precomputed entries without paths. + if self.validate_image_paths and any(path is not None for path in requested_paths): + raise ValueError( + "Cannot validate image paths because precomputed bbox entries do not contain " + "image_path metadata for path-based lookup." + ) + + if len(images) > len(self.entries): + raise ValueError( + f"Got {len(images)} images but only {len(self.entries)} precomputed bbox entries " + "are available for order-only matching." + ) + + for entry in self.entries[: len(images)]: + outputs.append(entry.to_detector_context(target_format=self.target_format)) + + return outputs + + +def precompute_detector_bboxes( + loader: Loader, + detector_runner: DetectorRunnerLike, + output_file: str | Path, + modes: tuple[str, ...] = ("train", "test"), + *, + bbox_format: str = "xywh", +) -> BBoxes: + """ + Run a detector runner on all images for the requested modes and save the results + to a BBoxes JSON artifact. + + The saved artifact is intended to be reused later for training a top-down pose + model without rerunning the detector. + """ + output_file = Path(output_file) + + result = {} + for mode in modes: + if hasattr(loader, "get_image_paths"): + image_paths = [Path(p) for p in loader.get_image_paths(mode)] # type: ignore[attr-defined] + else: + image_paths = [Path(p) for p in loader.image_filenames(mode)] + outputs = detector_runner.inference(image_paths) + + if len(outputs) != len(image_paths): + raise ValueError(f"Detector returned {len(outputs)} outputs for {len(image_paths)} {mode} images.") + + result[mode] = [ + BBoxEntry.from_detector_context( + out, + image_path=img_path, + bbox_format=bbox_format, + ) + for img_path, out in zip(image_paths, outputs, strict=False) + ] + + bboxes = BBoxes(**result) + bboxes.dump_json(output_file) + return bboxes + + +def build_precomputed_detector_runner_from_config( + model_cfg: dict, + mode: str, + *, + target_format: BBoxFormat = "xywh", + validate_image_paths: bool = False, +) -> PrecomputedDetectorRunner | None: + """ + Build a precomputed detector runner from model_cfg["data"]["precomputed_bboxes"]. + Returns None if no precomputed bbox file is configured. + """ + data_cfg = model_cfg.get("data", {}) + bbox_file = data_cfg.get("precomputed_bboxes") + if bbox_file is None: + return None + + bboxes = BBoxes.from_file(Path(bbox_file)) + return PrecomputedDetectorRunner.from_bboxes( + bboxes, + mode=mode, + target_format=target_format, + validate_image_paths=validate_image_paths, + ) diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/external/build.py b/deeplabcut/pose_estimation_pytorch/models/detectors/external/build.py new file mode 100644 index 000000000..e8d667db1 --- /dev/null +++ b/deeplabcut/pose_estimation_pytorch/models/detectors/external/build.py @@ -0,0 +1,64 @@ +import logging + +from deeplabcut.pose_estimation_pytorch.data.postprocessor import build_detector_postprocessor +from deeplabcut.pose_estimation_pytorch.data.preprocessor import build_bottom_up_preprocessor +from deeplabcut.pose_estimation_pytorch.data.transforms import build_transforms +from deeplabcut.pose_estimation_pytorch.models.detectors.external import EXTERNAL_DETECTORS +from deeplabcut.pose_estimation_pytorch.runners import build_inference_runner +from deeplabcut.pose_estimation_pytorch.task import Task + +logger = logging.getLogger(__name__) + + +def get_external_detector_inference_runner( + detector_cfg: dict, + batch_size: int, + device: str, + max_individuals: int, + color_mode: str, + transform=None, + inference_cfg=None, + min_bbox_score: float | None = None, +): + if transform is None: + transform = build_transforms({"scale_to_unit_range": True}) + + detector = EXTERNAL_DETECTORS.build(detector_cfg) + # to_device ? + try: + for param in detector.parameters(): + param.requires_grad = False + except (AttributeError, RuntimeError): + logger.warning( + "External detector does not have parameters that can be frozen. " + "Please review whether this is expected behavior for your detector." + ) + try: + detector.eval() + except AttributeError: + logger.warning( + "External detector does not have an eval() method. " + "Please review whether this is expected behavior for your detector." + ) + + runner = build_inference_runner( + task=Task.DETECT, + model=detector, + device=device, + snapshot_path=None, # always pre-trained + batch_size=batch_size, + # NOTE: the "bottom-up preprocessor" is a bit of a misnomer for this use case as + # this is a top-down pipeline, but what the preprocessor + # does here is to load the images, augment them, and convert them to tensors, + # which is what we need here as the input to the external detector. + preprocessor=build_bottom_up_preprocessor( + color_mode=color_mode, + transform=transform, + ), + postprocessor=build_detector_postprocessor( + max_individuals=max_individuals, + min_bbox_score=min_bbox_score, + ), + inference_cfg=inference_cfg, + ) + return runner diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/external/mock.py b/deeplabcut/pose_estimation_pytorch/models/detectors/external/mock.py new file mode 100644 index 000000000..cd1145721 --- /dev/null +++ b/deeplabcut/pose_estimation_pytorch/models/detectors/external/mock.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import torch + +from .base import EXTERNAL_DETECTORS, BaseExternalDetector + + +@EXTERNAL_DETECTORS.register_module +class MockExternalDetector(BaseExternalDetector): + """ + Simple detector for testing plumbing. + Returns one centered box per image. + """ + + def __init__(self, score: float = 0.9, label: int = 1) -> None: + super().__init__() + self.score = score + self.label = label + + def predict(self, images: list[torch.Tensor]): + outputs = [] + for image in images: + _, h, w = image.shape + box = torch.tensor([[w * 0.25, h * 0.25, w * 0.75, h * 0.75]], dtype=torch.float32) + score = torch.tensor([self.score], dtype=torch.float32) + label = torch.tensor([self.label], dtype=torch.long) + outputs.append( + { + "boxes": box, + "scores": score, + "labels": label, + } + ) + return outputs diff --git a/deeplabcut/pose_estimation_pytorch/runners/inference.py b/deeplabcut/pose_estimation_pytorch/runners/inference.py index 367685bb3..2272dc541 100644 --- a/deeplabcut/pose_estimation_pytorch/runners/inference.py +++ b/deeplabcut/pose_estimation_pytorch/runners/inference.py @@ -30,9 +30,11 @@ import deeplabcut.pose_estimation_pytorch.runners.shelving as shelving from deeplabcut.core.inferenceutils import calc_object_keypoint_similarity from deeplabcut.pose_estimation_pytorch.config.utils import update_config_by_dotpath +from deeplabcut.pose_estimation_pytorch.data.base import DetectorRunnerLike from deeplabcut.pose_estimation_pytorch.data.postprocessor import Postprocessor from deeplabcut.pose_estimation_pytorch.data.preprocessor import LoadImage, Preprocessor from deeplabcut.pose_estimation_pytorch.models.detectors import BaseDetector +from deeplabcut.pose_estimation_pytorch.models.detectors.external import BaseExternalDetector from deeplabcut.pose_estimation_pytorch.models.model import PoseModel from deeplabcut.pose_estimation_pytorch.runners.base import ModelType, Runner from deeplabcut.pose_estimation_pytorch.runners.dynamic_cropping import ( @@ -41,6 +43,7 @@ ) from deeplabcut.pose_estimation_pytorch.task import Task +DetectorModel = BaseDetector | BaseExternalDetector # NOTE @deruyter92 2026-04-28: AMD GPUs with DirectML inference mode currently do not # support torch.inference_mode, which is stricter than torch.no_grad. The ENV # variable is used to conditionally use torch.no_grad instead. See PR #3295. @@ -977,10 +980,10 @@ def _merge_conditions(self, bu_cond: np.ndarray) -> np.ndarray: return cond_pose[: len(self._idx_to_id)] -class DetectorInferenceRunner(InferenceRunner[BaseDetector]): +class DetectorInferenceRunner(InferenceRunner[DetectorModel]): """Runner for object detection inference.""" - def __init__(self, model: BaseDetector, **kwargs): + def __init__(self, model: DetectorModel, **kwargs): """ Args: model: The detector to use for inference. @@ -1020,6 +1023,246 @@ def predict(self, inputs: torch.Tensor, **kwargs) -> list[dict[str, dict[str, np return predictions +class DetectorToPoseInferenceRunner: + """ + Compose a detector runner with a top-down pose runner. + + Expected flow: + input image(s) + -> detector_runner.inference(...) + -> inject detector boxes into context["bboxes"] + -> pose_runner.inference(...) + + This is intentionally simple: + - it does not modify the pose runner internals + - it works with any detector_runner that satisfies DetectorRunnerLike + - it works with live detector runners and precomputed detector runners + """ + + def __init__( + self, + pose_runner, + detector_runner: DetectorRunnerLike, + *, + max_individuals: int | None = None, + num_joints: int | None = 17, + num_unique_bodyparts: int | None = 0, + fill_value: float = np.nan, + ) -> None: + self.pose_runner = pose_runner + self.detector_runner = detector_runner + self.max_individuals = None if max_individuals is None else max(1, max_individuals) + self.num_joints = 17 if num_joints is None else max(1, num_joints) + self.num_unique_bodyparts = 0 if num_unique_bodyparts is None else max(0, num_unique_bodyparts) + self.fill_value = fill_value + + @staticmethod + def _split_input_and_context( + item: str | Path | np.ndarray | tuple[str | Path | np.ndarray, dict[str, Any]], + ) -> tuple[str | Path | np.ndarray, dict[str, Any]]: + """ + Normalize an inference item into (image, context). + + Supported inputs: + - "path/to/image.png" + - Path("path/to/image.png") + - np.ndarray image + - (image, context_dict) + """ + if isinstance(item, tuple): + image, context = item + return image, dict(context) + return item, {} + + @staticmethod + def _normalize_detector_output(det: dict[str, Any]) -> tuple[np.ndarray, np.ndarray]: + """ + Convert detector output into the context format expected by TopDownCrop. + + Required: + - det["bboxes"] shaped [N, 4] in xywh format + + Optional: + - det["bbox_scores"] shaped [N] + """ + bboxes = np.asarray(det.get("bboxes", np.zeros((0, 4), dtype=np.float32)), dtype=np.float32).reshape(-1, 4) + + bbox_scores = np.asarray( + det.get("bbox_scores", np.ones((len(bboxes),), dtype=np.float32)), + dtype=np.float32, + ).reshape(-1) + + if len(bbox_scores) != len(bboxes): + raise ValueError( + f"Expected one bbox score per bbox, but got {len(bbox_scores)} scores for {len(bboxes)} boxes." + ) + + return bboxes, bbox_scores + + def _select_and_order_boxes( + self, + det: dict[str, Any], + context: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """ + Default strategy: + - sort by descending score + - keep at most max_individuals + + Future extension: + - if context contains reference boxes, reorder using IoU matching + """ + bboxes = np.asarray(det.get("bboxes", np.zeros((0, 4))), dtype=np.float32).reshape(-1, 4) + + if "bbox_scores" in det: + bbox_scores = np.asarray(det["bbox_scores"], dtype=np.float32).reshape(-1) + else: + bbox_scores = np.ones((len(bboxes),), dtype=np.float32) + + if len(bbox_scores) != len(bboxes): + raise ValueError( + f"Expected one bbox score per bbox, got {len(bbox_scores)} scores for {len(bboxes)} boxes." + ) + + if len(bboxes) == 0: + return bboxes, bbox_scores + + # Keep deterministic ordering: highest confidence first + # only if max_individuals is set + if self.max_individuals is not None: + order = np.argsort(-bbox_scores) + bboxes = bboxes[order] + bbox_scores = bbox_scores[order] + + # Only truncate if explicitly requested. + if self.max_individuals is not None: + bboxes = bboxes[: self.max_individuals] + bbox_scores = bbox_scores[: self.max_individuals] + + return bboxes.astype(np.float32, copy=False), bbox_scores.astype(np.float32, copy=False) + + @staticmethod + def _pad_first_dim(arr: np.ndarray, target_n: int, fill_value=np.nan) -> np.ndarray: + arr = np.asarray(arr) + + if arr.shape[0] == target_n: + return arr + + if arr.shape[0] > target_n: + return arr[:target_n] + + if not np.issubdtype(arr.dtype, np.floating): + arr = arr.astype(np.float32) + + pad_shape = (target_n - arr.shape[0],) + arr.shape[1:] + pad = np.full(pad_shape, fill_value, dtype=arr.dtype) + return np.concatenate([arr, pad], axis=0) + + def _empty_prediction(self, last_dim: int = 3) -> dict[str, np.ndarray]: + # If max_individuals is unspecified, an image with no detections should emit + # zero pose rows. If it is specified, emit a fixed-size padded empty output. + n_individuals = 0 if self.max_individuals is None else self.max_individuals + + pred = { + "bodyparts": np.full( + (n_individuals, self.num_joints, last_dim), + self.fill_value, + dtype=np.float32, + ) + } + + if self.num_unique_bodyparts > 0: + pred["unique_bodyparts"] = np.full( + (1, self.num_unique_bodyparts, last_dim), + self.fill_value, + dtype=np.float32, + ) + + return pred + + def _normalize_prediction( + self, + pred: dict[str, Any] | None, + *, + last_dim_hint: int = 3, + ) -> dict[str, np.ndarray]: + if pred is None or "bodyparts" not in pred: + return self._empty_prediction(last_dim=last_dim_hint) + + pred = dict(pred) + + bodyparts = np.asarray(pred["bodyparts"]) + if bodyparts.ndim != 3: + raise ValueError(f"Unexpected bodyparts shape: {bodyparts.shape}") + + last_dim = bodyparts.shape[-1] + if self.max_individuals is not None: + pred["bodyparts"] = self._pad_first_dim( + bodyparts, + self.max_individuals, + fill_value=self.fill_value, + ) + else: + pred["bodyparts"] = bodyparts + + if self.num_unique_bodyparts > 0: + if "unique_bodyparts" in pred: + ub = np.asarray(pred["unique_bodyparts"]) + if ub.ndim == 2: + ub = ub[None, ...] + pred["unique_bodyparts"] = self._pad_first_dim(ub, 1, fill_value=self.fill_value) + else: + pred["unique_bodyparts"] = np.full( + (1, self.num_unique_bodyparts, last_dim), + self.fill_value, + dtype=np.float32, + ) + + return pred + + @_inference_mode_decorator + def inference( + self, + images: (Iterable[str | Path | np.ndarray] | Iterable[tuple[str | Path | np.ndarray, dict[str, Any]]]), + shelf_writer: shelving.ShelfWriter | None = None, + ): + images = list(images) + + # Split once so we can preserve and copy incoming contexts. + split_items = [self._split_input_and_context(item) for item in images] + raw_images = [image for image, _ in split_items] + incoming_contexts = [context for _, context in split_items] + + # Detector sees raw image inputs only. The wrapper owns context enrichment. + detections = self.detector_runner.inference(raw_images) + + if len(detections) != len(raw_images): + raise ValueError(f"Detector returned {len(detections)} outputs for {len(raw_images)} input images.") + + enriched_inputs = [] + + for image, context, det in zip(raw_images, incoming_contexts, detections, strict=False): + # Copy context so caller-owned dictionaries are not mutated. + context = dict(context) + + bboxes, bbox_scores = self._select_and_order_boxes(det, context=context) + + context["bboxes"] = bboxes + context["bbox_scores"] = bbox_scores + context["detector_output"] = det + + enriched_inputs.append((image, context)) + + # The wrapped pose runner owns: + # - top-down preprocessing + # - pose prediction + # - postprocessing + # - shelf writing + # + # The wrapper only injects detector context and returns the pose runner output. + return self.pose_runner.inference(enriched_inputs, shelf_writer=shelf_writer) + + def build_inference_runner( task: Task, model: nn.Module, @@ -1031,8 +1274,9 @@ def build_inference_runner( dynamic: DynamicCropper | None = None, load_weights_only: bool | None = None, inference_cfg: InferenceConfig | dict | None = None, + detector_runner: DetectorRunnerLike | None = None, **kwargs, -) -> InferenceRunner: +) -> InferenceRunner | DetectorToPoseInferenceRunner: """Build a runner object according to a pytorch configuration file. Args: @@ -1090,6 +1334,17 @@ def build_inference_runner( dynamic = None if task == Task.COND_TOP_DOWN: - return CTDInferenceRunner(**kwargs) + runner = CTDInferenceRunner(**kwargs) + else: + runner = PoseInferenceRunner(dynamic=dynamic, **kwargs) + + if detector_runner is not None and task == Task.TOP_DOWN: + return DetectorToPoseInferenceRunner( + pose_runner=runner, + detector_runner=detector_runner, + max_individuals=kwargs.get("max_individuals", 1), + num_joints=kwargs.get("num_joints"), + num_unique_bodyparts=kwargs.get("num_unique_bodyparts", 0), + ) - return PoseInferenceRunner(dynamic=dynamic, **kwargs) + return runner diff --git a/deeplabcut/pose_estimation_pytorch/runners/train.py b/deeplabcut/pose_estimation_pytorch/runners/train.py index 6e87751d5..1f7ec6385 100644 --- a/deeplabcut/pose_estimation_pytorch/runners/train.py +++ b/deeplabcut/pose_estimation_pytorch/runners/train.py @@ -173,11 +173,11 @@ def _compute_epoch_metrics(self) -> dict[str, float]: raise NotImplementedError def _gpu_usage_str(self) -> str: - if not torch.cuda.is_available(): - return "" - used = torch.cuda.memory_reserved() / 1024**2 - total = torch.cuda.get_device_properties(0).total_memory / 1024**2 - return f", GPU: {used:.1f}/{total:.1f} MiB" + if "cuda" in str(self.device).lower() and torch.cuda.is_available(): + used = torch.cuda.memory_reserved(self.device) / 1024**2 + total = torch.cuda.get_device_properties(self.device).total_memory / 1024**2 + return f", GPU: {used:.1f}/{total:.1f} MiB" + return "" def fit( self, @@ -737,15 +737,12 @@ def build_optimizer( model: nn.Module, optimizer_config: dict, ) -> torch.optim.Optimizer: - """Builds an optimizer from a configuration. + """Builds an optimizer from a configuration.""" + optim_cls = getattr(torch.optim, optimizer_config["type"]) - Args: - model: The model to optimize. - optimizer_config: The configuration for the optimizer. + params = [p for p in model.parameters() if p.requires_grad] + if len(params) == 0: + raise ValueError("Cannot build optimizer: model has no trainable parameters.") - Returns: - The optimizer for the model built according to the given configuration. - """ - optim_cls = getattr(torch.optim, optimizer_config["type"]) - optimizer = optim_cls(params=model.parameters(), **optimizer_config["params"]) + optimizer = optim_cls(params=params, **optimizer_config["params"]) return optimizer diff --git a/examples/detector_test_full_api.py b/examples/detector_test_full_api.py new file mode 100644 index 000000000..9f98bfc6a --- /dev/null +++ b/examples/detector_test_full_api.py @@ -0,0 +1,873 @@ +""" +Synthetic end-to-end example for the external-detector / precomputed-bbox workflow +in DeepLabCut PyTorch top-down pose estimation. + +If you are mostly interested in the process of using a detector see the "Usage" section below. + +This example is intentionally focused and highly documented. It demonstrates the +"offline boxes" workflow, which is typically the easiest path to integrate custom +external detectors and curate their outputs before training a DLC pose model. + +What this script does +--------------------- +1. Creates a minimal, valid DLC-style multi-animal project on disk with synthetic data: + - black RGB frames + - one white square per frame + - four annotated keypoints (one at each corner) +2. Builds a real ``DLCLoader`` on top of that project. +3. Runs a tiny detector adapter to generate per-image bounding boxes. +4. Saves those boxes via ``precompute_detector_bboxes(...)`` as a JSON artifact. +5. Creates/updates the PyTorch pose config so training uses those precomputed boxes. +6. Verifies that ``DLCLoader.create_dataset(..., detector_runner=...)`` picks up the + detector boxes before training. +7. Calls the real high-level ``train_network(...)`` API while patching only: + - the pose-model builder (to use a tiny demo model), and + - the transform builder (to keep the example deterministic and lightweight). +8. Optionally writes the synthetic frames into a short video and runs + ``video_inference(...)`` using per-frame precomputed bounding-box context. + +Important scope note +-------------------- +This script is intended as: +- a runnable proof-of-concept for the new external / precomputed detector path, +- a clearly documented example for hackathon participants, +- and a strong integration test blueprint. + +It is *not* intended as a realistic training recipe for production-quality models. +The tiny pose model used here is only meant to prove that the end-to-end plumbing +works with the real DLC APIs. + +Usage +----- +Run as a script: + + python detector_test_full_api.py --output-dir /tmp/dlc_external_demo + +If ``--output-dir`` is omitted, a temporary directory is created automatically. +Add ``--no-inference`` to skip the video inference step. +""" + +from __future__ import annotations + +import argparse +import copy +import pickle +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from PIL import Image + +import deeplabcut.core.config as config_utils +import deeplabcut.utils.auxiliaryfunctions as af +from deeplabcut.core.engine import Engine +from deeplabcut.pose_estimation_pytorch.apis.training import train_network +from deeplabcut.pose_estimation_pytorch.config.make_pose_config import ( + _yaml_safe_value, + make_pytorch_pose_config, +) +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxes +from deeplabcut.pose_estimation_pytorch.data.dlcloader import ( + DLCLoader, + build_dlc_dataframe_columns, +) +from deeplabcut.pose_estimation_pytorch.models.detectors.external.base import ( + build_precomputed_detector_runner_from_config, + precompute_detector_bboxes, +) +from deeplabcut.pose_estimation_pytorch.task import Task + +# ----------------------------------------------------------------------------- +# Lightweight helpers used to keep the demo deterministic and robust +# ----------------------------------------------------------------------------- + + +class IdentityTopDownTransform: + """ + Minimal transform object matching the contract expected by PoseDataset. + + It preserves image / keypoints / bboxes exactly as given, and always returns a + dict containing those keys so dataset.py does not fail on missing 'bboxes'. + """ + + def __call__(self, **kwargs): + transformed = dict(kwargs) + transformed.setdefault("image", None) + transformed.setdefault("keypoints", []) + transformed.setdefault("bboxes", []) + return transformed + + def __repr__(self): + return "IdentityTopDownTransform()" + + +# ----------------------------------------------------------------------------- +# Synthetic data helpers +# ----------------------------------------------------------------------------- + + +BODYPARTS = ["tl", "tr", "br", "bl"] +INDIVIDUALS = ["square"] + + +@dataclass +class SyntheticFrame: + image: np.ndarray + bbox_xywh: np.ndarray + keypoints_xyv: np.ndarray + rel_index: tuple[str, str, str] + abs_path: Path + + +@dataclass +class SyntheticProject: + project_root: Path + config_path: Path + pose_config_path: Path + precomputed_bboxes_path: Path + frames: list[SyntheticFrame] + + +class SquareThresholdDetectorRunner: + """ + Tiny stand-in for an external detector runner. + + It implements the minimal detector-runner contract expected by the external + detector / precomputed bbox workflow: + + inference(images, shelf_writer=None) + -> list[{"bboxes": ..., "bbox_scores": ...}] + + The detector simply thresholds non-zero pixels and returns one enclosing bbox per + image in ``xywh`` format. + """ + + def __init__(self, threshold: int = 1, score: float = 0.99): + self.threshold = threshold + self.score = float(score) + + @staticmethod + def _load_image(item: str | Path | np.ndarray | tuple[Any, dict[str, Any]]) -> np.ndarray: + if isinstance(item, tuple): + item = item[0] + if isinstance(item, np.ndarray): + return item + return np.asarray(Image.open(item).convert("RGB")) + + def inference(self, images, shelf_writer=None): + outputs = [] + for item in images: + image = self._load_image(item) + mask = image[..., 0] >= self.threshold + ys, xs = np.where(mask) + if len(xs) == 0 or len(ys) == 0: + bboxes = np.zeros((0, 4), dtype=np.float32) + scores = np.zeros((0,), dtype=np.float32) + else: + x0 = float(xs.min()) + y0 = float(ys.min()) + x1 = float(xs.max()) + y1 = float(ys.max()) + # inclusive pixel extent -> width/height = max-min+1 + bbox = np.array([[x0, y0, x1 - x0 + 1.0, y1 - y0 + 1.0]], dtype=np.float32) + score = np.array([self.score], dtype=np.float32) + bboxes = bbox + scores = score + + outputs.append( + { + "bboxes": bboxes, + "bbox_scores": scores, + } + ) + return outputs + + +class TinyCornerPoseModel(nn.Module): + """ + Minimal trainable pose model for one individual with four keypoints. + + This model is deliberately tiny. It only serves to make the high-level training + and inference paths run with a lightweight, deterministic model while still + exercising: + - the real DLCLoader, + - the real create_dataset(..., detector_runner=...), + - the real train_network(...) API, + - snapshot saving/loading, + - and video_inference(...) with precomputed bbox context. + """ + + def __init__(self): + super().__init__() + self.backbone = nn.Identity() + self.pool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 12) # 4 keypoints * (x, y, conf) + + def forward(self, x, cond_kpts=None): + pooled = self.pool(x).flatten(1) # [B, 3] + pred = self.fc(pooled).reshape(len(x), 1, 4, 3) + pred[..., 2] = torch.sigmoid(pred[..., 2]) + return {"pred_keypoints": pred} + + def get_target(self, outputs, annotations): + return annotations["keypoints"].float().to(outputs["pred_keypoints"].device) + + def get_loss(self, outputs, target): + pred = outputs["pred_keypoints"] + loss_xy = ((pred[..., :2] - target[..., :2]) ** 2).mean() + loss_conf = ((pred[..., 2] - 1.0) ** 2).mean() + total = loss_xy + 0.1 * loss_conf + return { + "total_loss": total, + "loss_xy": loss_xy, + "loss_conf": loss_conf, + } + + def get_predictions(self, outputs): + return { + "bodypart": { + "poses": outputs["pred_keypoints"], + } + } + + +# ----------------------------------------------------------------------------- +# Project construction +# ----------------------------------------------------------------------------- + + +def make_square_image( + image_size: tuple[int, int] = (128, 128), + square_xywh: tuple[int, int, int, int] = (32, 40, 24, 24), +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Create one synthetic RGB frame with a white square on a black background. + + Returns: + image: uint8 array [H, W, 3] + bbox_xywh: float32 array [4] + keypoints_xyv: float32 array [4, 3] + """ + h, w = image_size + x, y, bw, bh = square_xywh + + image = np.zeros((h, w, 3), dtype=np.uint8) + image[y : y + bh, x : x + bw] = 255 + + keypoints = np.array( + [ + [x, y, 2.0], + [x + bw - 1, y, 2.0], + [x + bw - 1, y + bh - 1, 2.0], + [x, y + bh - 1, 2.0], + ], + dtype=np.float32, + ) + bbox = np.array([x, y, bw, bh], dtype=np.float32) + return image, bbox, keypoints + + +def _save_rgb_png(image: np.ndarray, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(image).save(path) + + +def _create_project_config(project_root: Path) -> tuple[dict, Any]: + """ + Build a minimal multi-animal DLC project config with a single individual. + + We intentionally use the multi-animal pickle dataset pathway because it is much + easier to synthesize than the legacy .mat single-animal format. + """ + cfg_file, yaml_file = af.create_config_template(multianimal=True) + yaml_file.width = 10_000 + + videos_dir = project_root / "videos" + videos_dir.mkdir(parents=True, exist_ok=True) + dummy_video = videos_dir / "dummy.mp4" + dummy_video.write_bytes(b"") + + cfg_file["Task"] = "synthetic-square" + cfg_file["scorer"] = "synthetic" + cfg_file["date"] = "2026-04-09" + cfg_file["project_path"] = project_root.as_posix() + cfg_file["video_sets"] = {dummy_video.as_posix(): {"crop": "0, 128, 0, 128"}} + + cfg_file["multianimalproject"] = True + cfg_file["individuals"] = copy.deepcopy(INDIVIDUALS) + cfg_file["multianimalbodyparts"] = copy.deepcopy(BODYPARTS) + cfg_file["uniquebodyparts"] = [] + cfg_file["bodyparts"] = "MULTI!" + + cfg_file["TrainingFraction"] = [0.75] + cfg_file["iteration"] = 0 + cfg_file["snapshotindex"] = -1 + + return cfg_file, yaml_file + + +def _build_collected_data_dataframe( + scorer: str, + frames: list[SyntheticFrame], +) -> pd.DataFrame: + from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters + + params = PoseDatasetParameters( + bodyparts=BODYPARTS, + unique_bpts=[], + individuals=INDIVIDUALS, + with_center_keypoints=False, + color_mode="RGB", + top_down_crop_size=(32, 32), + top_down_crop_margin=0, + top_down_crop_with_context=True, + ) + + columns = build_dlc_dataframe_columns(scorer, params, with_likelihood=False) + + rows = [] + index = [] + for frame in frames: + xy = frame.keypoints_xyv[:, :2].reshape(1, len(BODYPARTS), 2) + rows.append(xy.reshape(-1)) + index.append(frame.rel_index) + + df = pd.DataFrame( + data=np.stack(rows, axis=0), + index=pd.MultiIndex.from_tuples(index), + columns=columns, + ) + return df.sort_index(axis=0) + + +def _build_dataset_pickle_entries(frames: list[SyntheticFrame]) -> list[dict[str, Any]]: + entries = [] + for frame in frames: + joints = np.array( + [[i, kp[0], kp[1]] for i, kp in enumerate(frame.keypoints_xyv)], + dtype=np.float32, + ) + h, w = frame.image.shape[:2] + entries.append( + { + "image": frame.rel_index, + "size": (3, h, w), + "joints": { + 0: joints, + }, + } + ) + return entries + + +def _ensure_loader_get_image_paths() -> None: + """ + Compatibility shim for versions where precompute_detector_bboxes(...) expects a + loader.get_image_paths(...) method but Loader only exposes image_filenames(...). + """ + if not hasattr(DLCLoader, "get_image_paths"): + DLCLoader.get_image_paths = DLCLoader.image_filenames + + +# ----------------------------------------------------------------------------- +# POSE CONFIG +# ----------------------------------------------------------------------------- + + +def _write_or_update_pose_config( + project_cfg: dict, + pose_config_path: Path, + precomputed_bboxes: str | Path, + *, + crop_size: tuple[int, int] = (32, 32), + epochs: int = 1, + batch_size: int = 1, +) -> dict: + """ + Create a PyTorch pose config for the external / precomputed detector workflow, + then patch it down to a tiny, CPU-friendly demo setup. + """ + pose_config_path.parent.mkdir(parents=True, exist_ok=True) + + pose_cfg = make_pytorch_pose_config( + project_config=project_cfg, + pose_config_path=pose_config_path, + top_down=True, + detector_mode="external", + save=True, + precomputed_bboxes=precomputed_bboxes, + bbox_source="detection_bbox", + external_detector_metadata={ + "name": "SquareThresholdDetectorRunner", + "kind": "synthetic_demo", + }, + ) + + # Patch the config down to a minimal, fast, CPU-friendly training setup. + pose_cfg.setdefault("metadata", {}) + pose_cfg["metadata"]["bodyparts"] = copy.deepcopy(BODYPARTS) + pose_cfg["metadata"]["unique_bodyparts"] = [] + pose_cfg["metadata"]["individuals"] = copy.deepcopy(INDIVIDUALS) + + pose_cfg["method"] = "td" + pose_cfg["net_type"] = pose_cfg.get("net_type", "resnet_50") + pose_cfg["color_mode"] = "RGB" + pose_cfg["with_center_keypoints"] = False + + pose_cfg.setdefault("model", {}) + pose_cfg["model"]["type"] = "TinyCornerPoseModel" + + pose_cfg.setdefault("data", {}) + pose_cfg["data"]["bbox_source"] = "detection_bbox" + pose_cfg["data"]["precomputed_bboxes"] = Path(precomputed_bboxes).as_posix() + pose_cfg["data"]["bbox_validate_image_paths"] = False + pose_cfg["data"].setdefault("bbox_match_iou_threshold", 0.1) + pose_cfg["data"].setdefault("bbox_fallback_to_gt", False) + pose_cfg["data"].setdefault("bbox_margin", 0) + pose_cfg["data"]["colormode"] = "RGB" + pose_cfg["data"].setdefault("train", {}) + pose_cfg["data"].setdefault("inference", {}) + pose_cfg["data"]["train"].setdefault("top_down_crop", {}) + pose_cfg["data"]["train"]["top_down_crop"].update( + { + "width": int(crop_size[0]), + "height": int(crop_size[1]), + "margin": 0, + "crop_with_context": True, + } + ) + pose_cfg["data"]["inference"].setdefault("top_down_crop", {}) + pose_cfg["data"]["inference"]["top_down_crop"].update( + { + "width": int(crop_size[0]), + "height": int(crop_size[1]), + "margin": 0, + "crop_with_context": True, + } + ) + + pose_cfg.setdefault("train_settings", {}) + pose_cfg["train_settings"].update( + { + "seed": 0, + "epochs": int(epochs), + "batch_size": int(batch_size), + "dataloader_workers": 0, + "dataloader_pin_memory": False, + "display_iters": 1, + } + ) + + pose_cfg.setdefault("runner", {}) + pose_cfg["runner"]["optimizer"] = { + "type": "SGD", + "params": {"lr": 0.1}, + } + # Skip evaluation in this demo to keep it focused on the training path. + pose_cfg["runner"]["eval_interval"] = 999 + pose_cfg["runner"]["snapshots"] = { + "max_snapshots": 1, + "save_epochs": 1, + "save_optimizer_state": True, + } + + # Compatibility stub: current train_network() still expects detector.train_settings.epochs + pose_cfg.setdefault("detector", {}) + pose_cfg["detector"].setdefault("train_settings", {}) + pose_cfg["detector"]["train_settings"]["epochs"] = 0 + + pose_cfg = _yaml_safe_value(pose_cfg) + config_utils.write_config(pose_config_path, pose_cfg, overwrite=True) + return pose_cfg + + +def make_synthetic_square_dlc_project( + output_dir: str | Path, + *, + num_frames: int = 4, + image_size: tuple[int, int] = (128, 128), + crop_size: tuple[int, int] = (32, 32), + shuffle: int = 1, +) -> SyntheticProject: + """ + Create a minimal, valid DLC-style project on disk using real image files, + CollectedData.h5, dataset split pickle, dataset pickle and a PyTorch pose config. + """ + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # 1) project config + project_cfg, yaml_file = _create_project_config(output_dir) + config_path = output_dir / "config.yaml" + with open(config_path, "w", encoding="utf-8") as f: + yaml_file.dump(project_cfg, f) + + # 2) synthetic frames on disk + frames: list[SyntheticFrame] = [] + placements = [ + (24, 24, 20, 20), + (64, 16, 24, 24), + (20, 72, 28, 16), + (72, 72, 18, 26), + ] + placements = placements[:num_frames] + + for i, square in enumerate(placements): + image, bbox, keypoints = make_square_image(image_size=image_size, square_xywh=square) + rel_index = ("labeled-data", "synthetic-square", f"img{i:03d}.png") + abs_path = output_dir.joinpath(*rel_index) + _save_rgb_png(image, abs_path) + frames.append( + SyntheticFrame( + image=image, + bbox_xywh=bbox, + keypoints_xyv=keypoints, + rel_index=rel_index, + abs_path=abs_path, + ) + ) + + # 3) CollectedData_.h5 + trainset_dir = output_dir / af.get_training_set_folder(project_cfg) + trainset_dir.mkdir(parents=True, exist_ok=True) + collected_path = trainset_dir / f"CollectedData_{project_cfg['scorer']}.h5" + collected_df = _build_collected_data_dataframe(project_cfg["scorer"], frames) + collected_df.to_hdf(collected_path, key="df_with_missing") + + # 4) DLC multi-animal dataset pickle + train_frac = int(100 * project_cfg["TrainingFraction"][0]) + dataset_prefix = f"{project_cfg['Task']}_{project_cfg['scorer']}{train_frac}shuffle{shuffle}" + dataset_pickle_path = trainset_dir / f"{dataset_prefix}.pickle" + with open(dataset_pickle_path, "wb") as f: + pickle.dump(_build_dataset_pickle_entries(frames), f) + + # 5) split pickle consumed by DLCLoader.load_split(...) + # meta[1] -> train ids, meta[2] -> test ids + train_ids = list(range(max(1, len(frames) - 1))) + test_ids = [len(frames) - 1] + split_pickle_path = trainset_dir / f"Documentation_data-{project_cfg['Task']}_{train_frac}shuffle{shuffle}.pickle" + with open(split_pickle_path, "wb") as f: + pickle.dump((None, train_ids, test_ids), f) + + # 6) model folder / PyTorch config path + model_folder = af.get_model_folder( + project_cfg["TrainingFraction"][0], + shuffle, + project_cfg, + engine=Engine.PYTORCH, + modelprefix="", + ) + pose_config_path = output_dir / model_folder / "train" / Engine.PYTORCH.pose_cfg_name + precomputed_bboxes_path = output_dir / model_folder / "train" / "precomputed_bboxes.json" + + _write_or_update_pose_config( + project_cfg=project_cfg, + pose_config_path=pose_config_path, + precomputed_bboxes=precomputed_bboxes_path, + crop_size=crop_size, + epochs=1, + batch_size=1, + ) + + return SyntheticProject( + project_root=output_dir, + config_path=config_path, + pose_config_path=pose_config_path, + precomputed_bboxes_path=precomputed_bboxes_path, + frames=frames, + ) + + +# ----------------------------------------------------------------------------- +# Workflow helpers +# ----------------------------------------------------------------------------- + + +def generate_precomputed_detector_boxes(project: SyntheticProject, shuffle: int = 1) -> BBoxes: + """ + Canonical external-detector workflow step: + 1. build a real DLCLoader on the project + 2. run a detector runner + 3. save the results as a BBoxes JSON artifact + """ + _ensure_loader_get_image_paths() + + loader = DLCLoader(config=project.config_path, shuffle=shuffle, trainset_index=0) + detector_runner = SquareThresholdDetectorRunner() + + bboxes = precompute_detector_bboxes( + loader=loader, + detector_runner=detector_runner, + output_file=project.precomputed_bboxes_path, + modes=("train", "test"), + bbox_format="xywh", + ) + return bboxes + + +def verify_loader_uses_precomputed_boxes(project: SyntheticProject, shuffle: int = 1) -> None: + """ + Pre-flight check before training: + prove that the real DLCLoader picks up the saved precomputed detector boxes and + rewrites top-down annotation bboxes accordingly. + """ + loader = DLCLoader(config=project.config_path, shuffle=shuffle, trainset_index=0) + runner = build_precomputed_detector_runner_from_config( + loader.model_cfg, + mode="train", + target_format="xywh", + validate_image_paths=False, + ) + if runner is None: + raise RuntimeError("Failed to build a precomputed detector runner from the pose config.") + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=runner, + ) + + # Check the first training-frame annotation bbox against the known synthetic square. + expected = np.asarray(project.frames[0].bbox_xywh, dtype=np.float32) + found = np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32) + np.testing.assert_allclose(found, expected, atol=1e-5) + + +# ----------------------------------------------------------------------------- +# TRAINING +# ----------------------------------------------------------------------------- + + +def run_train_network_demo(project: SyntheticProject, shuffle: int = 1) -> TinyCornerPoseModel: + """ + Run the real high-level train_network(...) API while patching only: + - PoseModel.build(...) -> tiny trainable demo model + - build_transforms(...) -> identity transform preserving bbox/keypoint contract + + Returns the trained tiny model instance so callers can inspect parameter changes. + """ + import deeplabcut.pose_estimation_pytorch.apis.training as training_api + + tiny_model = TinyCornerPoseModel() + before = {name: p.detach().cpu().clone() for name, p in tiny_model.named_parameters()} + + with ( + patch.object( + training_api.PoseModel, + "build", + side_effect=lambda *args, **kwargs: tiny_model, + ), + patch.object( + training_api, + "build_transforms", + side_effect=lambda cfg: IdentityTopDownTransform(), + ), + ): + train_network( + config=project.config_path, + shuffle=shuffle, + trainingsetindex=0, + device="cpu", + ) + + changed = [name for name, p in tiny_model.named_parameters() if not torch.equal(before[name], p.detach().cpu())] + if len(changed) == 0: + raise AssertionError("Expected at least one model parameter to change during train_network(...).") + + return tiny_model + + +# ----------------------------------------------------------------------------- +# INFERENCE +# ----------------------------------------------------------------------------- + + +def write_synthetic_video( + project: SyntheticProject, + *, + video_name: str = "synthetic_video.mp4", + fps: int = 5, +) -> Path: + import cv2 + + video_path = project.project_root / video_name + h, w = project.frames[0].image.shape[:2] + + writer = cv2.VideoWriter( + str(video_path), + cv2.VideoWriter_fourcc(*"mp4v"), + fps, + (w, h), + ) + if not writer.isOpened(): + raise RuntimeError(f"Failed to open video writer for {video_path}") + + for frame in project.frames: + # OpenCV expects BGR + bgr = frame.image[..., ::-1].copy() + writer.write(bgr) + + writer.release() + return video_path + + +def build_video_context_from_detector(project: SyntheticProject) -> list[dict[str, np.ndarray]]: + """ + Run the same tiny detector on the synthetic frame arrays and build per-frame + context compatible with VideoIterator / video_inference. + """ + detector = SquareThresholdDetectorRunner() + outputs = detector.inference([f.image for f in project.frames]) + return outputs + + +def run_video_inference_demo(project: SyntheticProject, shuffle: int = 1): + """ + Run video_inference(...) on a synthetic video using per-frame precomputed bbox + context. This demonstrates the cleanest current inference story for the external / + offline boxes workflow. + """ + import deeplabcut.pose_estimation_pytorch.apis.utils as api_utils + import deeplabcut.pose_estimation_pytorch.apis.videos as videos_api + + loader = DLCLoader(config=project.config_path, shuffle=shuffle, trainset_index=0) + + # Get the most recent pose snapshot + snapshots = api_utils.get_model_snapshots(-1, loader.model_folder, loader.pose_task) + if len(snapshots) == 0: + raise RuntimeError("No pose snapshot found after training.") + snapshot = snapshots[-1] + + video_path = write_synthetic_video(project) + contexts = build_video_context_from_detector(project) + + video_iterator = videos_api.VideoIterator(video_path) + video_iterator.set_context(contexts) + + with ( + patch.object( + api_utils.PoseModel, + "build", + side_effect=lambda *args, **kwargs: TinyCornerPoseModel(), + ), + patch.object( + api_utils, + "build_transforms", + side_effect=lambda cfg: IdentityTopDownTransform(), + ), + ): + pose_runner = api_utils.get_pose_inference_runner( + model_config=loader.model_cfg, + snapshot_path=snapshot.path, + max_individuals=len(loader.model_cfg["metadata"]["individuals"]), + batch_size=1, + transform=None, + dynamic=None, + cond_provider=None, + ctd_tracking=False, + inference_cfg=None, + ) + + predictions = videos_api.video_inference( + video=video_iterator, + pose_runner=pose_runner, + detector_runner=None, # contexts already contain bboxes + shelf_writer=None, + robust_nframes=False, + show_gpu_memory=False, + ) + + # Basic sanity checks + assert len(predictions) == len(project.frames), ( + f"Expected {len(project.frames)} frame predictions, got {len(predictions)}" + ) + + for pred in predictions: + assert "bodyparts" in pred + bodyparts = pred["bodyparts"] + + # Expect one individual, four keypoints, xyz/conf + assert bodyparts.ndim == 3 + assert bodyparts.shape[1] == 4 + assert bodyparts.shape[2] >= 3 + + # Optionally also serialize a DLC-style H5 for the synthetic video + videos_api.create_df_from_prediction( + predictions=predictions, + dlc_scorer="synthetic_demo", + multi_animal=True, + model_cfg=loader.model_cfg, + output_path=project.project_root, + output_prefix="synthetic_video_demo", + save_as_csv=False, + ) + + return predictions + + +# ----------------------------------------------------------------------------- +# Main entry point +# ----------------------------------------------------------------------------- + + +def main(output_dir: str | Path | None = None, run_inference: bool = True) -> SyntheticProject: + owns_tmp = False + if output_dir is None: + output_dir = Path(tempfile.mkdtemp(prefix="dlc_synth_square_demo_")) + owns_tmp = True + else: + output_dir = Path(output_dir) + max_step = 4 if not run_inference else 5 + + project = make_synthetic_square_dlc_project(output_dir) + print(f"[1/{max_step}] Synthetic DLC project created at: {project.project_root}") + print(f" config.yaml: {project.config_path}") + print(f" pytorch_config.yaml:{project.pose_config_path}") + + bboxes = generate_precomputed_detector_boxes(project) + print(f"[2/{max_step}] Precomputed detector boxes written to: {project.precomputed_bboxes_path}") + print(f" train entries: {len(bboxes.train)}, test entries: {len(bboxes.test)}") + + verify_loader_uses_precomputed_boxes(project) + print(f"[3/{max_step}] Verified: real DLCLoader.create_dataset(...) uses the saved detector boxes.") + + model = run_train_network_demo(project) + print(f"[4/{max_step}] train_network(...) completed successfully.") + print(f" tiny model trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + + if run_inference: + predictions = run_video_inference_demo(project) + print(f"[5/{max_step}] video_inference(...) completed successfully on {len(predictions)} synthetic frames.") + + if owns_tmp: + print("\nNote: a temporary project directory was created automatically.") + print(f" You can inspect it here: {project.project_root}") + + return project + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Synthetic DLC top-down training + inference demo with precomputed detector boxes." + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Directory in which to create the synthetic project. If omitted, a temporary directory is used.", + ) + parser.add_argument( + "--no-inference", + action="store_false", + dest="run_inference", + help="Skip the video inference demo after training.", + ) + args = parser.parse_args() + main(args.output_dir, run_inference=args.run_inference) diff --git a/scripts/external_detector_workflow.py b/scripts/external_detector_workflow.py new file mode 100644 index 000000000..fc7a640e7 --- /dev/null +++ b/scripts/external_detector_workflow.py @@ -0,0 +1,693 @@ +""" +External-detector workflow example for DeepLabCut PyTorch top-down pose estimation. + +This example is intended for those who already have a *real* DeepLabCut project +with labeled data and a created shuffle / PyTorch model folder. + +Description +----------------------------- +1. Open a normal DLC project with a real ``config.yaml``. +2. Choose a DLC pose model. +3. Remember to create a training shuffle using DLC ! +4. Plug in your own external detector by implementing a tiny adapter class. +5. Run the detector offline on the train/test images and save the results as + ``precomputed_bboxes.json``. +6. Create/update the project's ``pytorch_config.yaml`` so the pose model trains in + top-down mode using those precomputed boxes. +7. Train the DLC pose model via the ``train_network(...)`` API. +8. Run inference either on: + - a video (using per-frame bbox context, optionally cached to disk), or + - a folder of image frames. + +Purpose +----------------------- +The goal is to make it easy to use *your own detector* while keeping *DLC pose models* +for training and inference. +In this workflow, the detector is responsible only for +providing bounding boxes (proposals / crops), and DeepLabCut still handles: +- dataset loading, +- crop generation, +- pose-model training, +- snapshot management, +- and inference. + +Important prerequisites +----------------------- +Before using this script, you should already have: +1. a normal DeepLabCut project with labeled data (from RCP), +2. a created training dataset / shuffle for the PyTorch engine (provided or your own), +3. and a valid ``config.yaml``. + + +What you should edit +-------------------- +Users should mainly edit: +- ``CONFIG`` -> path to their DLC ``config.yaml`` +- ``POSE_MODEL`` -> which DLC pose model to use +- ``MyExternalDetector`` -> the detector adapter, where most of the work will happen +- a few curated training / crop settings in ``USER_SETTINGS`` + +What you usually should *not* edit (unless you want/have to) +---------------------------------- +- ``DLCLoader`` internals +- bbox artifact schema internals +- runner construction internals +- raw snapshot loading +- pose-model internals + +Example usage +(CLI if needed, but I'd suggest using a notebook for dev and debug. +RCP makes this easy, just import the script from your notebook and use the functions directly): +------------- +Train only: + + python external_detector_real_project_workflow.py --config /path/to/config.yaml --train + +Train + video inference: + + python external_detector_real_project_workflow.py \ + --config /path/to/config.yaml \ + --train \ + --video /path/to/video.mp4 + +Folder-of-frames inference: + + python external_detector_real_project_workflow.py \ + --config /path/to/config.yaml \ + --images-dir /path/to/frames +""" + +from __future__ import annotations + +import argparse +import pickle +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + +from deeplabcut.pose_estimation_pytorch.apis.training import train_network +from deeplabcut.pose_estimation_pytorch.apis.utils import get_pose_inference_runner +from deeplabcut.pose_estimation_pytorch.apis.videos import ( + VideoIterator, + create_df_from_prediction, + video_inference, +) +from deeplabcut.pose_estimation_pytorch.config.make_pose_config import make_pytorch_pose_config +from deeplabcut.pose_estimation_pytorch.data import DLCLoader +from deeplabcut.pose_estimation_pytorch.models.detectors.external.base import ( + precompute_detector_bboxes, +) +from deeplabcut.pose_estimation_pytorch.runners.inference import DetectorToPoseInferenceRunner +from deeplabcut.pose_estimation_pytorch.task import Task + +# ----------------------------------------------------------------------------- +# User-facing settings +# ----------------------------------------------------------------------------- + +EXAMPLE_POSE_MODELS = [ + "hrnet_w32", + "resnet_50", + "rtmpose_x", + "rtmpose_s", + "rtmpose_m", +] + + +@dataclass +class UserSettings: + pose_model: str = "resnet_50" + shuffle: int = 1 + trainingsetindex: int = 0 + batch_size: int = 4 + epochs: int = 50 + crop_width: int = 256 + crop_height: int = 256 + bbox_match_iou_threshold: float = 0.1 + bbox_fallback_to_gt: bool = False + bbox_validate_image_paths: bool = False + display_iters: int = 50 + device: str | None = None + + +# ----------------------------------------------------------------------------- +# Detector adapter section (participants should replace this with their own detector) +# ----------------------------------------------------------------------------- + + +class PretrainedDetectorModel: + """ + Replace the internals of this class with your own detector. + + Required contract: + inference(images, shelf_writer=None) -> list[dict] + + For each input image, return a dict in DLC detector-context format: + { + "bboxes": np.ndarray[N, 4], # XYWH in pixels + "bbox_scores": np.ndarray[N], + } + + Supported input image elements typically include: + - ``Path`` / ``str`` to an image file, + - ``np.ndarray`` image arrays, + - or ``(image, context)`` tuples. + + The simplest way to adapt your detector is: + 1. load the image if needed, + 2. run your detector, + 3. convert its output boxes to XYWH pixel coordinates, + 4. return the list of per-image dicts. + + Notes + ----- + - Boxes must be in ``xywh`` format because the current DLC top-down crop path + expects that downstream. + - If your detector naturally returns ``xyxy`` boxes, convert them before returning. + The pose_estimation_pytorch.data.bboxes.BBoxEntry schemas + already have converter functions in place, feel free to extend them. + """ + + def inference(self, images, shelf_writer=None): + raise NotImplementedError( + "Replace `MyExternalDetector.inference(...)` with your own detector adapter.\n" + "It must return a list of dicts with keys `bboxes` and `bbox_scores`, where\n" + "`bboxes` has shape [N, 4] in XYWH pixel coordinates." + ) + + +# ----------------------------------------------------------------------------- +# Small utility helpers +# ----------------------------------------------------------------------------- + + +def ensure_loader_get_image_paths() -> None: + """ + Compatibility shim for versions where precompute_detector_bboxes(...) expects a + loader.get_image_paths(...) method but Loader only exposes image_filenames(...). + """ + if not hasattr(DLCLoader, "get_image_paths"): + DLCLoader.get_image_paths = DLCLoader.image_filenames + + +def list_images_in_folder(images_dir: str | Path) -> list[Path]: + images_dir = Path(images_dir) + exts = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"} + paths = [p for p in sorted(images_dir.iterdir()) if p.suffix.lower() in exts and p.is_file()] + if len(paths) == 0: + raise FileNotFoundError(f"No supported image files found in {images_dir}") + return paths + + +def infer_top_down_flag(pose_model_name: str) -> bool: + """ + Heuristic for the config builder. + + - Backbone names need `top_down=True` to become TD pose models. + - Explicit top-down configs like `resnet_50` are already TD, but passing + `top_down=True` is harmless for the config-builder path. + - This example is specifically for top-down detector-driven workflows. + """ + return True + + +# ----------------------------------------------------------------------------- +# Config preparation helpers +# ----------------------------------------------------------------------------- + + +def prepare_external_topdown_pose_config( + config: str | Path, + settings: UserSettings, + precomputed_bboxes_path: str | Path, + external_detector_metadata: dict | None = None, + modelprefix: str = "", +) -> tuple[DLCLoader, Path]: + """ + Create/update the DLC PyTorch pose config for the external / precomputed detector workflow. + + This function: + 1. loads the real DLC project through DLCLoader, + 2. creates / overwrites the project's pytorch_config.yaml using make_pytorch_pose_config(...), + 3. applies a few curated updates relevant to this workflow. + """ + loader = DLCLoader( + config=config, + trainset_index=settings.trainingsetindex, + shuffle=settings.shuffle, + modelprefix=modelprefix, + ) + + pose_cfg = make_pytorch_pose_config( + project_config=loader.project_cfg, + pose_config_path=loader.model_config_path, + net_type=settings.pose_model, + top_down=infer_top_down_flag(settings.pose_model), + detector_mode="external", + save=True, + precomputed_bboxes=precomputed_bboxes_path, + bbox_source="detection_bbox", + external_detector_metadata=external_detector_metadata or {}, + ) + + # Validate the chosen model really resolves to top-down. + if Task(pose_cfg["method"]) != Task.TOP_DOWN: + raise ValueError( + f"The selected pose model '{settings.pose_model}' did not resolve to a top-down model. " + f"Choose a top-down-capable model. Recommended examples: {EXAMPLE_POSE_MODELS}" + ) + + # Apply curated configuration updates via the canonical loader.update_model_cfg(...) path. + cfg_updates = { + "data.precomputed_bboxes": Path(precomputed_bboxes_path).as_posix(), + "data.bbox_source": "detection_bbox", + "data.bbox_match_iou_threshold": settings.bbox_match_iou_threshold, + "data.bbox_fallback_to_gt": settings.bbox_fallback_to_gt, + "data.bbox_validate_image_paths": settings.bbox_validate_image_paths, + "data.train.top_down_crop.width": settings.crop_width, + "data.train.top_down_crop.height": settings.crop_height, + "data.inference.top_down_crop.width": settings.crop_width, + "data.inference.top_down_crop.height": settings.crop_height, + "train_settings.batch_size": settings.batch_size, + "train_settings.epochs": settings.epochs, + "train_settings.display_iters": settings.display_iters, + # detector training is disabled in the external/offline workflow + "detector.train_settings.epochs": 0, + } + + if settings.device is not None: + cfg_updates["device"] = settings.device + + loader.update_model_cfg(cfg_updates) + return loader, loader.model_config_path + + +# ----------------------------------------------------------------------------- +# Training helpers +# ----------------------------------------------------------------------------- + + +def save_external_detector_bboxes( + config: str | Path, + detector_runner, + settings: UserSettings, + output_file: str | Path, + modelprefix: str = "", +): + """ + Run the external detector on the train/test images of a real DLC project and save + the results as a reusable JSON bbox artifact. + """ + ensure_loader_get_image_paths() + loader = DLCLoader( + config=config, + trainset_index=settings.trainingsetindex, + shuffle=settings.shuffle, + modelprefix=modelprefix, + ) + + return precompute_detector_bboxes( + loader=loader, + detector_runner=detector_runner, + output_file=output_file, + modes=("train", "test"), + bbox_format="xywh", + ) + + +def train_external_topdown_pose_model( + config: str | Path, + settings: UserSettings, + modelprefix: str = "", +) -> None: + """ + Train the configured top-down pose model using the real DLC train_network(...) API. + """ + train_network( + config=config, + shuffle=settings.shuffle, + trainingsetindex=settings.trainingsetindex, + modelprefix=modelprefix, + device=settings.device, + batch_size=settings.batch_size, + epochs=settings.epochs, + display_iters=settings.display_iters, + ) + + +# ----------------------------------------------------------------------------- +# Inference helpers +# ----------------------------------------------------------------------------- + + +def _load_or_compute_video_box_context( + video_path: str | Path, + detector_runner, + cache_file: str | Path | None = None, +) -> list[dict[str, np.ndarray]]: + """ + Compute (or load) per-frame detector boxes for a video. + + If `cache_file` is provided and exists, contexts are loaded from it. + Otherwise, the detector is run on the video frames and the result is optionally + saved to `cache_file`. + + The cache is intentionally a simple pickle so participants can inspect / curate it. + """ + if cache_file is not None: + cache_file = Path(cache_file) + if cache_file.exists(): + with open(cache_file, "rb") as f: + return pickle.load(f) + + video_iter = VideoIterator(video_path) + contexts = detector_runner.inference(list(video_iter)) + + if cache_file is not None: + cache_file.parent.mkdir(parents=True, exist_ok=True) + with open(cache_file, "wb") as f: + pickle.dump(contexts, f, pickle.HIGHEST_PROTOCOL) + + return contexts + + +def analyze_video_with_external_boxes( + config: str | Path, + video: str | Path, + detector_runner, + settings: UserSettings, + modelprefix: str = "", + video_box_cache: str | Path | None = None, +): + """ + Run top-down video inference using offline / precomputed per-frame bbox context. + + This uses the current cleanest inference path for external detectors: + 1. compute or load per-frame bbox contexts, + 2. attach them to a VideoIterator, + 3. call video_inference(...) with detector_runner=None. + """ + loader = DLCLoader( + config=config, + trainset_index=settings.trainingsetindex, + shuffle=settings.shuffle, + modelprefix=modelprefix, + ) + + snapshots = loader.snapshots(detector=False, best_in_last=True) + if len(snapshots) == 0: + raise RuntimeError("No pose snapshots were found. Train the model first.") + snapshot = snapshots[-1] + + pose_runner = get_pose_inference_runner( + model_config=loader.model_cfg, + snapshot_path=snapshot.path, + batch_size=1, + device=settings.device, + max_individuals=len(loader.model_cfg["metadata"]["individuals"]), + transform=None, + dynamic=None, + cond_provider=None, + ctd_tracking=False, + inference_cfg=None, + ) + + contexts = _load_or_compute_video_box_context(video, detector_runner, cache_file=video_box_cache) + video_iterator = VideoIterator(video) + video_iterator.set_context(contexts) + + predictions = video_inference( + video=video_iterator, + pose_runner=pose_runner, + detector_runner=None, + cropping=None, + shelf_writer=None, + robust_nframes=False, + show_gpu_memory=False, + ) + + dlc_scorer = loader.scorer(snapshot) + output_path = Path(video).parent + output_prefix = Path(video).stem + dlc_scorer + "_external" + + create_df_from_prediction( + predictions=predictions, + dlc_scorer=dlc_scorer, + multi_animal=loader.project_cfg["multianimalproject"], + model_cfg=loader.model_cfg, + output_path=output_path, + output_prefix=output_prefix, + save_as_csv=False, + ) + + return predictions + + +def analyze_image_folder_with_external_boxes( + config: str | Path, + images_dir: str | Path, + detector_runner, + settings: UserSettings, + modelprefix: str = "", +): + """ + Run top-down inference on a folder of image frames. + + This uses the precomputed bbox context path directly by building a list of + `(image_path, context)` tuples and giving them to the pose runner. + """ + loader = DLCLoader( + config=config, + trainset_index=settings.trainingsetindex, + shuffle=settings.shuffle, + modelprefix=modelprefix, + ) + + snapshots = loader.snapshots(detector=False, best_in_last=True) + if len(snapshots) == 0: + raise RuntimeError("No pose snapshots were found. Train the model first.") + snapshot = snapshots[-1] + + pose_runner = get_pose_inference_runner( + model_config=loader.model_cfg, + snapshot_path=snapshot.path, + batch_size=1, + device=settings.device, + max_individuals=len(loader.model_cfg["metadata"]["individuals"]), + transform=None, + dynamic=None, + cond_provider=None, + ctd_tracking=False, + inference_cfg=None, + ) + + image_paths = list_images_in_folder(images_dir) + + composite_runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + max_individuals=len(loader.model_cfg["metadata"]["individuals"]), + num_joints=len(loader.model_cfg["metadata"]["bodyparts"]), + num_unique_bodyparts=len(loader.model_cfg["metadata"].get("unique_bodyparts", [])), + ) + + predictions = composite_runner.inference(image_paths) + + dlc_scorer = loader.scorer(snapshot) + output_path = Path(images_dir) + output_prefix = output_path.name + dlc_scorer + "_external" + + create_df_from_prediction( + predictions=predictions, + dlc_scorer=dlc_scorer, + multi_animal=loader.project_cfg["multianimalproject"], + model_cfg=loader.model_cfg, + output_path=output_path, + output_prefix=output_prefix, + save_as_csv=True, + ) + + return predictions + + +# ----------------------------------------------------------------------------- +# Main workflow +# ----------------------------------------------------------------------------- + + +def main( + config: str | Path, + settings: UserSettings, + train: bool = False, + video: str | Path | None = None, + images_dir: str | Path | None = None, + modelprefix: str = "", + video_box_cache: str | Path | None = None, +): + config = Path(config) + if not config.exists(): + raise FileNotFoundError(f"Config file not found: {config}") + + # Update the detector args + detector = PretrainedDetectorModel() + + # Build loader once to resolve the canonical model folder. + loader = DLCLoader( + config=config, + trainset_index=settings.trainingsetindex, + shuffle=settings.shuffle, + modelprefix=modelprefix, + ) + + bbox_file = loader.model_folder / "precomputed_bboxes.json" + + print("=== External detector + DLC top-down workflow ===") + print(f"Project config: {config}") + print(f"Shuffle: {settings.shuffle}") + print(f"Training set index: {settings.trainingsetindex}") + print(f"Pose model: {settings.pose_model}") + print(f"BBox artifact: {bbox_file}") + print() + + print("[1/4] Running external detector on the project images and saving offline boxes...") + save_external_detector_bboxes( + config=config, + detector_runner=detector, + settings=settings, + output_file=bbox_file, + modelprefix=modelprefix, + ) + print(" Done.") + + print("[2/4] Creating/updating pytorch_config.yaml for external top-down training...") + loader, pose_cfg_path = prepare_external_topdown_pose_config( + config=config, + settings=settings, + precomputed_bboxes_path=bbox_file, + external_detector_metadata={ + "name": detector.__class__.__name__, + "integration": "external_offline_boxes_example", + }, + modelprefix=modelprefix, + ) + print(f" Wrote pose config: {pose_cfg_path}") + + if train: + print("[3/4] Training the DLC pose model with offline detector boxes...") + train_external_topdown_pose_model(config=config, settings=settings, modelprefix=modelprefix) + print(" Training finished.") + else: + print("[3/4] Skipping training (--train not given).") + + if video is not None and images_dir is not None: + raise ValueError("Please provide either --video or --images-dir, not both.") + + if video is not None: + print("[4/4] Running video inference with offline boxes...") + preds = analyze_video_with_external_boxes( + config=config, + video=video, + detector_runner=detector, + settings=settings, + modelprefix=modelprefix, + video_box_cache=video_box_cache, + ) + print(f" Wrote predictions for {len(preds)} video frames.") + elif images_dir is not None: + print("[4/4] Running image-folder inference with offline boxes...") + preds = analyze_image_folder_with_external_boxes( + config=config, + images_dir=images_dir, + detector_runner=detector, + settings=settings, + modelprefix=modelprefix, + ) + print(f" Wrote predictions for {len(preds)} images.") + else: + print("[4/4] No inference target provided. Use --video or --images-dir to run inference.") + + print() + print("Workflow complete.") + print("Benchmark time :3") + + +# ----------------------------------------------------------------------------- +# CLI +# ----------------------------------------------------------------------------- + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="External detector + DLC top-down workflow example (offline boxes).") + parser.add_argument( + "--config", + type=str, + required=True, + help="Path to the DLC project config.yaml", + ) + parser.add_argument( + "--pose-model", + type=str, + default="top_down_resnet_50", + help=(f"DLC pose model to use. You can pass a raw DLC net_type. Recommended examples: {EXAMPLE_POSE_MODELS}"), + ) + parser.add_argument("--shuffle", type=int, default=1, help="Shuffle index") + parser.add_argument("--trainingsetindex", type=int, default=0, help="TrainingFraction index") + parser.add_argument("--batch-size", type=int, default=4, help="Pose training batch size") + parser.add_argument("--epochs", type=int, default=50, help="Pose training epochs") + parser.add_argument("--crop-width", type=int, default=256, help="Top-down crop width") + parser.add_argument("--crop-height", type=int, default=256, help="Top-down crop height") + parser.add_argument("--display-iters", type=int, default=50, help="Loss logging interval during training") + parser.add_argument("--device", type=str, default=None, help="Torch device override, e.g. cpu/cuda/mps") + parser.add_argument( + "--train", + action="store_true", + help="Run training after preparing the offline bbox artifact and pose config.", + ) + parser.add_argument( + "--video", + type=str, + default=None, + help="Optional path to a video on which to run inference using offline boxes.", + ) + parser.add_argument( + "--images-dir", + type=str, + default=None, + help="Optional path to a folder of image frames on which to run inference using offline boxes.", + ) + parser.add_argument( + "--video-box-cache", + type=str, + default=None, + help="Optional pickle cache for per-frame video detector boxes.", + ) + parser.add_argument( + "--modelprefix", + type=str, + default="", + help="Optional DLC modelprefix if your project uses one.", + ) + + args = parser.parse_args() + + settings = UserSettings( + pose_model=args.pose_model, + shuffle=args.shuffle, + trainingsetindex=args.trainingsetindex, + batch_size=args.batch_size, + epochs=args.epochs, + crop_width=args.crop_width, + crop_height=args.crop_height, + display_iters=args.display_iters, + device=args.device, + ) + + main( + config=args.config, + settings=settings, + train=args.train, + video=args.video, + images_dir=args.images_dir, + modelprefix=args.modelprefix, + video_box_cache=args.video_box_cache, + ) diff --git a/tests/pose_estimation_pytorch/apis/test_precomp_bbox_training.py b/tests/pose_estimation_pytorch/apis/test_precomp_bbox_training.py new file mode 100644 index 000000000..4ccf08d4d --- /dev/null +++ b/tests/pose_estimation_pytorch/apis/test_precomp_bbox_training.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset + +import deeplabcut.pose_estimation_pytorch.data.base as base_mod +from deeplabcut.pose_estimation_pytorch import build_training_runner +from deeplabcut.pose_estimation_pytorch.data.base import Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod, BBoxEntry, BBoxes +from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters +from deeplabcut.pose_estimation_pytorch.models.detectors.external import PrecomputedDetectorRunner +from deeplabcut.pose_estimation_pytorch.task import Task + +# ----------------------------------------------------------------------------- +# Tiny dataset stand-in so we can inspect create_dataset() output directly +# ----------------------------------------------------------------------------- + + +class DummyPoseDataset: + def __init__( + self, + images, + annotations, + transform, + mode, + task, + parameters, + ctd_config=None, + ): + self.images = images + self.annotations = annotations + self.transform = transform + self.mode = mode + self.task = task + self.parameters = parameters + self.ctd_config = ctd_config + + +@pytest.fixture(autouse=True) +def patch_pose_dataset(monkeypatch): + monkeypatch.setattr(base_mod, "PoseDataset", DummyPoseDataset) + + +# ----------------------------------------------------------------------------- +# Fake multi-animal DLC-style loader +# ----------------------------------------------------------------------------- + + +class FakeMultiAnimalDLCLoader(Loader): + """ + Minimal multi-animal loader: + - one image + - two individuals + - each individual has keypoints that imply a different bbox + """ + + def __init__(self, precomputed_bboxes_path: Path): + self.project_root = Path(".") + self.image_root = Path(".") + self.model_config_path = Path("dummy_pytorch_config.yaml") + + self.model_cfg = { + "method": "td", + "data": { + "bbox_source": BBoxComputationMethod.DETECTION_BBOX.value, + "precomputed_bboxes": precomputed_bboxes_path.as_posix(), + "bbox_margin": 5, + "bbox_match_iou_threshold": 0.1, + "bbox_fallback_to_gt": False, + }, + "runner": {}, + "train_settings": {}, + } + + self.pose_task = Task.TOP_DOWN + self._loaded_data = {} + + # Two individuals in one image, with clearly separated keypoints + # Individual A (left side) + keypoints_a = np.array( + [ + [20.0, 20.0, 2.0], + [30.0, 30.0, 2.0], + ], + dtype=np.float32, + ) + + # Individual B (right side) + keypoints_b = np.array( + [ + [70.0, 20.0, 2.0], + [80.0, 30.0, 2.0], + ], + dtype=np.float32, + ) + + self._payload = { + "images": [ + { + "id": 1, + "file_name": "img0.png", + "width": 100, + "height": 60, + } + ], + "annotations": [ + { + "id": 1, + "image_id": 1, + "category_id": 1, + "individual": "animal_a", + "individual_id": 0, + # placeholder/stale bbox - should be replaced + "bbox": np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + "area": 12.0, + "keypoints": keypoints_a, + "num_keypoints": 2, + "iscrowd": 0, + }, + { + "id": 2, + "image_id": 1, + "category_id": 1, + "individual": "animal_b", + "individual_id": 1, + # placeholder/stale bbox - should be replaced + "bbox": np.array([5.0, 6.0, 7.0, 8.0], dtype=np.float32), + "area": 56.0, + "keypoints": keypoints_b, + "num_keypoints": 2, + "iscrowd": 0, + }, + ], + } + + def load_data(self, mode: str = "train"): + self._loaded_data.setdefault(mode, self._payload) + return self._loaded_data[mode] + + def get_dataset_parameters(self) -> PoseDatasetParameters: + return PoseDatasetParameters( + bodyparts=["nose", "tail"], + unique_bpts=[], + individuals=["animal_a", "animal_b"], + with_center_keypoints=False, + color_mode="RGB", + top_down_crop_size=(64, 64), + top_down_crop_margin=0, + top_down_crop_with_context=True, + ) + + def default_bbox_method(self, task: Task): + # DLCLoader-like backward compatibility + if task in (Task.TOP_DOWN, Task.DETECT): + return BBoxComputationMethod.KEYPOINTS + return None + + +# ----------------------------------------------------------------------------- +# Tiny train dataset for PoseTrainingRunner +# ----------------------------------------------------------------------------- + + +class TinyTrainDataset(Dataset): + """ + Minimal dataset that yields the batch structure expected by PoseTrainingRunner. + + It uses the annotations produced by create_dataset(...), so training still depends + on the offline / precomputed detector assignment done earlier. + """ + + def __init__(self, annotations: list[dict]): + self.annotations = annotations + + def __len__(self): + return 2 + + def __getitem__(self, idx): + # Build keypoints tensor from the matched annotations + # shape: [num_individuals, num_bodyparts, 3] + kpts = np.stack([ann["keypoints"] for ann in self.annotations], axis=0).astype(np.float32) + + sample = { + "image": torch.zeros((3, 32, 32), dtype=torch.float32), + "annotations": { + "keypoints": torch.tensor(kpts, dtype=torch.float32), + "with_center_keypoints": torch.tensor(False), + }, + "offsets": torch.tensor([0.0, 0.0], dtype=torch.float32), + "scales": torch.tensor([1.0, 1.0], dtype=torch.float32), + "context": {}, + } + return sample + + +# ----------------------------------------------------------------------------- +# Tiny pose model compatible with PoseTrainingRunner +# ----------------------------------------------------------------------------- + + +class TinyPoseModel(nn.Module): + """ + Minimal trainable pose model: + - one scalar parameter + - produces dummy pose predictions + - implements the methods PoseTrainingRunner expects + """ + + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.tensor(0.5, dtype=torch.float32)) + + # only needed if someone ever uses load_head_weights=False + self.backbone = nn.Identity() + + def forward(self, x, cond_kpts=None): + batch_size = x.shape[0] + + # Predict 2 individuals x 2 bodyparts x 3 values (x, y, visibility) + pred = torch.ones((batch_size, 2, 2, 3), device=x.device, dtype=torch.float32) * self.weight + pred[..., 2] = 1.0 + return {"pred_keypoints": pred} + + def get_target(self, outputs, annotations): + return annotations["keypoints"].to(outputs["pred_keypoints"].device).float() + + def get_loss(self, outputs, target): + pred_xy = outputs["pred_keypoints"][..., :2] + target_xy = target[..., :2] + loss = ((pred_xy - target_xy) ** 2).mean() + return {"total_loss": loss} + + def get_predictions(self, outputs): + return { + "bodypart": { + "poses": outputs["pred_keypoints"], + } + } + + +# ----------------------------------------------------------------------------- +# The actual end-to-end test +# ----------------------------------------------------------------------------- + + +def test_offline_precomputed_topdown_multi_animal_training_e2e(tmp_path: Path): + """ + End-to-end test for the offline / precomputed external detector workflow. + + Proves that: + 1. precomputed detector boxes can be loaded from config + 2. create_dataset(...) builds the correct multi-animal top-down dataset + 3. training runs through the high-level training API + 4. only the pose model is trained + 5. the detector is not needed anymore once the dataset is built + """ + + # ------------------------------------------------------------------------- + # 1. Create precomputed detector artifact with boxes intentionally reversed + # relative to annotation order. Matching must recover the correct assignment. + # ------------------------------------------------------------------------- + bboxes_path = tmp_path / "precomputed_bboxes.json" + + precomputed = BBoxes( + train=[ + BBoxEntry( + # reversed order on purpose: + # first bbox belongs to animal_b (right side), second to animal_a (left side) + bboxes=[ + (65.0, 15.0, 20.0, 20.0), # should match annotation 2 / animal_b + (15.0, 15.0, 20.0, 20.0), # should match annotation 1 / animal_a + ], + bbox_scores=[0.9, 0.8], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ], + test=[ + BBoxEntry( + bboxes=[ + (65.0, 15.0, 20.0, 20.0), + (15.0, 15.0, 20.0, 20.0), + ], + bbox_scores=[0.9, 0.8], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ], + ) + precomputed.dump_json(bboxes_path) + + # ------------------------------------------------------------------------- + # 2. Build loader + precomputed detector runner from config-like state + # ------------------------------------------------------------------------- + loader = FakeMultiAnimalDLCLoader(precomputed_bboxes_path=bboxes_path) + + detector_runner = PrecomputedDetectorRunner.from_bboxes( + BBoxes.from_file(bboxes_path), + mode="train", + target_format="xywh", + validate_image_paths=True, + ) + + # ------------------------------------------------------------------------- + # 3. Create top-down dataset using offline precomputed detector boxes + # ------------------------------------------------------------------------- + raw_before = [np.asarray(ann["bbox"], dtype=np.float32).copy() for ann in loader.load_data("train")["annotations"]] + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + # Annotation order is [animal_a, animal_b]. + # Matching should recover the correct detector box for each animal + # even though the detector outputs were stored in reversed order. + actual_bbox_a = np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32) + actual_bbox_b = np.asarray(dataset.annotations[1]["bbox"], dtype=np.float32) + + expected_bbox_a = np.asarray([15.0, 15.0, 20.0, 20.0], dtype=np.float32) + expected_bbox_b = np.asarray([65.0, 15.0, 20.0, 20.0], dtype=np.float32) + + np.testing.assert_allclose(actual_bbox_a, expected_bbox_a) + np.testing.assert_allclose(actual_bbox_b, expected_bbox_b) + + # Cached raw annotations must remain untouched + raw_after = [np.asarray(ann["bbox"], dtype=np.float32) for ann in loader.load_data("train")["annotations"]] + np.testing.assert_allclose(raw_before[0], raw_after[0]) + np.testing.assert_allclose(raw_before[1], raw_after[1]) + + # ------------------------------------------------------------------------- + # 4. Once dataset is built, training should no longer depend on detector I/O + # Prove this by making detector inference crash if called again. + # ------------------------------------------------------------------------- + def _should_not_be_called(*args, **kwargs): + raise AssertionError("Detector inference should not be called during pose training when using offline data.") + + detector_runner.inference = _should_not_be_called # type: ignore[method-assign] + + # ------------------------------------------------------------------------- + # 5. Build tiny train/valid loaders from the matched annotations + # ------------------------------------------------------------------------- + train_ds = TinyTrainDataset(dataset.annotations) + valid_ds = TinyTrainDataset(dataset.annotations) + + train_loader = DataLoader(train_ds, batch_size=1, shuffle=False) + valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False) + + # ------------------------------------------------------------------------- + # 6. Build high-level training runner + # ------------------------------------------------------------------------- + model = TinyPoseModel() + + runner_config = { + "optimizer": { + "type": "SGD", + "params": { + "lr": 0.1, + }, + }, + "eval_interval": 1, + "snapshots": { + "max_snapshots": 1, + "save_epochs": 1, + "save_optimizer_state": True, + }, + } + + model_folder = tmp_path / "models" + model_folder.mkdir(parents=True, exist_ok=True) + + runner = build_training_runner( + runner_config=runner_config, + model_folder=model_folder, + task=Task.TOP_DOWN, + model=model, + device="cpu", + snapshot_path=None, + ) + + # ------------------------------------------------------------------------- + # 7. Assert optimizer only contains trainable pose params + # ------------------------------------------------------------------------- + optimizer_param_ids = {id(p) for group in runner.optimizer.param_groups for p in group["params"]} + model_param_ids = {id(p) for p in model.parameters() if p.requires_grad} + + assert optimizer_param_ids == model_param_ids + + # ------------------------------------------------------------------------- + # 8. Run one short training cycle and assert pose params changed + # ------------------------------------------------------------------------- + before = {name: p.detach().cpu().clone() for name, p in model.named_parameters()} + + runner.fit( + train_loader=train_loader, + valid_loader=valid_loader, + epochs=1, + display_iters=1, + ) + + after = {name: p.detach().cpu() for name, p in model.named_parameters()} + + changed = [] + for name in before: + if not torch.equal(before[name], after[name]): + changed.append(name) + + assert len(changed) > 0, "Expected at least one pose model parameter to change during training." diff --git a/tests/pose_estimation_pytorch/data/test_bbox.py b/tests/pose_estimation_pytorch/data/test_bbox.py new file mode 100644 index 000000000..b5a569e3f --- /dev/null +++ b/tests/pose_estimation_pytorch/data/test_bbox.py @@ -0,0 +1,302 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Tests bbox-source behavior for dataset creation.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +import deeplabcut.pose_estimation_pytorch.data.base as base_mod +from deeplabcut.pose_estimation_pytorch.data.base import Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import BBoxComputationMethod +from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters +from deeplabcut.pose_estimation_pytorch.data.dlcloader import DLCLoader +from deeplabcut.pose_estimation_pytorch.data.utils import bbox_from_keypoints +from deeplabcut.pose_estimation_pytorch.task import Task + + +class DummyPoseDataset: + """Tiny stand-in for PoseDataset so we can inspect what create_dataset passes through.""" + + def __init__( + self, + images, + annotations, + transform, + mode, + task, + parameters, + ctd_config=None, + ): + self.images = images + self.annotations = annotations + self.transform = transform + self.mode = mode + self.task = task + self.parameters = parameters + self.ctd_config = ctd_config + + +class FakeDLCLoader(Loader): + """ + Minimal Loader used to test create_dataset() logic without needing a real DLC project. + It mimics DLCLoader's backward-compatible default bbox behavior. + """ + + def __init__(self, bbox_source: str | None = None): + # Do not call Loader.__init__() — we set just what create_dataset() needs. + self.project_root = Path(".") + self.image_root = Path(".") + self.model_config_path = Path("dummy_pytorch_config.yaml") + + self.model_cfg = { + "method": "td", + "data": { + "bbox_margin": 7, # IMPORTANT: used to test that configured margin is respected + }, + "train_settings": {}, + } + if bbox_source is not None: + self.model_cfg["data"]["bbox_source"] = bbox_source + + self.pose_task = Task.TOP_DOWN + self._loaded_data = {} + + # One cached payload, reused across calls — useful to detect accidental mutation + self._payload = { + "images": [ + { + "id": 1, + "file_name": "img0.png", + "width": 100, + "height": 80, + } + ], + "annotations": [ + { + "id": 1, + "image_id": 1, + "category_id": 1, + "individual": "animal", + "individual_id": 0, + # Placeholder bbox that should be replaced in keypoint mode + "bbox": np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + "area": 12.0, + # Two visible keypoints + "keypoints": np.array( + [ + [30.0, 40.0, 2.0], + [50.0, 60.0, 2.0], + ], + dtype=np.float32, + ), + "num_keypoints": 2, + "iscrowd": 0, + } + ], + } + + def load_data(self, mode: str = "train"): + self._loaded_data.setdefault(mode, self._payload) + return self._loaded_data[mode] + + def get_dataset_parameters(self) -> PoseDatasetParameters: + return PoseDatasetParameters( + bodyparts=["nose", "tail"], + unique_bpts=[], + individuals=["animal"], + with_center_keypoints=False, + color_mode="RGB", + top_down_crop_size=(256, 256), + top_down_crop_margin=0, + top_down_crop_with_context=True, + ) + + def default_bbox_method(self, task: Task) -> str | None: + # Mimic the new DLCLoader backward-compatible behavior + if task in (Task.TOP_DOWN, Task.DETECT): + return "keypoints" + return None + + +class DummyDetectorRunner: + """Simple detector runner returning one bbox per image.""" + + def __init__(self, bbox, score=0.9): + self._bbox = np.asarray(bbox, dtype=np.float32) + self._score = float(score) + + def inference(self, images, shelf_writer=None): + return [ + { + "bboxes": np.asarray([self._bbox], dtype=np.float32), + "bbox_scores": np.asarray([self._score], dtype=np.float32), + } + for _ in images + ] + + +@pytest.fixture(autouse=True) +def patch_pose_dataset(monkeypatch): + """ + Replace PoseDataset with a tiny dummy object so tests focus purely on loader logic. + """ + monkeypatch.setattr(base_mod, "PoseDataset", DummyPoseDataset) + + +def test_dlcloader_default_bbox_method_is_backward_compatible(): + """ + DLCLoader should preserve historical behavior: + detector and top-down tasks default to keypoint-derived boxes. + """ + loader = object.__new__(DLCLoader) + + assert DLCLoader.default_bbox_method(loader, Task.TOP_DOWN) == BBoxComputationMethod.KEYPOINTS + assert DLCLoader.default_bbox_method(loader, Task.DETECT) == BBoxComputationMethod.KEYPOINTS + assert DLCLoader.default_bbox_method(loader, Task.BOTTOM_UP) is None + + +@pytest.mark.parametrize("task", [Task.TOP_DOWN, Task.DETECT]) +def test_create_dataset_defaults_to_keypoints_for_dlc_style_loader(task): + """ + Backward compatibility regression test: + when no bbox_source is explicitly configured, a DLCLoader-like loader should + derive boxes from keypoints for TOP_DOWN and DETECT tasks. + """ + loader = FakeDLCLoader() + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=task, + detector_runner=None, + ) + + ann = dataset.annotations[0] + actual_bbox = np.asarray(ann["bbox"], dtype=np.float32) + + expected_bbox = bbox_from_keypoints( + keypoints=loader._payload["annotations"][0]["keypoints"], + image_h=loader._payload["images"][0]["height"], + image_w=loader._payload["images"][0]["width"], + margin=loader.model_cfg["data"]["bbox_margin"], + ).astype(np.float32) + + # Ensure configured bbox_margin is respected + np.testing.assert_allclose(actual_bbox, expected_bbox) + + # Stronger regression guard: + # this should NOT be the hardcoded margin=20 result from _add_bbox_annotations() + hardcoded_bbox = bbox_from_keypoints( + keypoints=loader._payload["annotations"][0]["keypoints"], + image_h=loader._payload["images"][0]["height"], + image_w=loader._payload["images"][0]["width"], + margin=20, + ).astype(np.float32) + + assert not np.allclose(actual_bbox, hardcoded_bbox), ( + "create_dataset() appears to be relying on the hardcoded bbox=20 fallback " + "instead of recomputing with configured bbox_margin" + ) + + +def test_explicit_bbox_source_gt_preserves_existing_bbox(): + """ + Explicit bbox_source='gt' must override the backward-compatible default and keep + the annotation bbox unchanged. + """ + loader = FakeDLCLoader(bbox_source="gt") + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=None, + ) + + ann = dataset.annotations[0] + actual_bbox = np.asarray(ann["bbox"], dtype=np.float32) + expected_bbox = np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32) + + np.testing.assert_allclose(actual_bbox, expected_bbox) + + +def test_detector_runner_overrides_default_bbox_source(): + """ + If a detector_runner is provided, create_dataset() must use detector boxes even if + the loader would otherwise default to keypoint-derived boxes. + """ + loader = FakeDLCLoader() + detector_runner = DummyDetectorRunner(bbox=[11.0, 12.0, 13.0, 14.0], score=0.95) + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + ann = dataset.annotations[0] + actual_bbox = np.asarray(ann["bbox"], dtype=np.float32) + + np.testing.assert_allclose(actual_bbox, np.asarray([11.0, 12.0, 13.0, 14.0], dtype=np.float32)) + + +def test_create_dataset_does_not_mutate_cached_load_data_annotations(): + """ + Regression test for the refactor: + create_dataset() should deep-copy annotations before rewriting bboxes, otherwise + cached load_data() results become stateful and unsafe across repeated calls. + """ + loader = FakeDLCLoader() + detector_runner = DummyDetectorRunner(bbox=[21.0, 22.0, 23.0, 24.0], score=0.88) + + # Sanity check original cached bbox + raw_before = np.asarray(loader.load_data("train")["annotations"][0]["bbox"], dtype=np.float32).copy() + np.testing.assert_allclose(raw_before, np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + + # This call should NOT mutate the cached payload + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + # Dataset bbox should use detector output + np.testing.assert_allclose( + np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32), + np.asarray([21.0, 22.0, 23.0, 24.0], dtype=np.float32), + ) + + # Cached raw annotations must remain untouched + raw_after = np.asarray(loader.load_data("train")["annotations"][0]["bbox"], dtype=np.float32) + np.testing.assert_allclose(raw_after, np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + + +def test_explicit_gt_is_still_overridden_by_detector_runner(): + loader = FakeDLCLoader(bbox_source="gt") + detector_runner = DummyDetectorRunner(bbox=[31.0, 32.0, 33.0, 34.0]) + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + np.testing.assert_allclose( + np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32), + np.asarray([31.0, 32.0, 33.0, 34.0], dtype=np.float32), + ) diff --git a/tests/pose_estimation_pytorch/models/external_detectors/test_build.py b/tests/pose_estimation_pytorch/models/external_detectors/test_build.py new file mode 100644 index 000000000..eb1e7fbf6 --- /dev/null +++ b/tests/pose_estimation_pytorch/models/external_detectors/test_build.py @@ -0,0 +1,105 @@ +import numpy as np + +from deeplabcut.pose_estimation_pytorch.data.postprocessor import ( + build_detector_postprocessor, +) +from deeplabcut.pose_estimation_pytorch.data.preprocessor import ( + build_bottom_up_preprocessor, +) +from deeplabcut.pose_estimation_pytorch.data.transforms import build_transforms +from deeplabcut.pose_estimation_pytorch.models.detectors.external import ( + EXTERNAL_DETECTORS, +) +from deeplabcut.pose_estimation_pytorch.runners.inference import build_inference_runner +from deeplabcut.pose_estimation_pytorch.task import Task + + +def test_external_detector_end_to_end_inference(): + """ + End-to-end test for external (pretrained) detector inference. + + This test verifies that: + - an external detector can be built from the registry + - preprocessing runs correctly + - DetectorInferenceRunner executes inference + - outputs have the expected DLC detection structure + """ + + # ------------------------- + # 1. Build the external detector + # ------------------------- + detector_cfg = { + "type": "MockExternalDetector", + "score": 0.9, + "label": 1, + } + + detector = EXTERNAL_DETECTORS.build(detector_cfg) + detector.eval() + + # ------------------------- + # 2. Build preprocessor & postprocessor + # ------------------------- + transform = build_transforms({"scale_to_unit_range": True}) + + preprocessor = build_bottom_up_preprocessor( + color_mode="RGB", + transform=transform, + ) + + postprocessor = build_detector_postprocessor( + max_individuals=5, + min_bbox_score=0.0, + ) + + # ------------------------- + # 3. Build inference runner (high-level API) + # ------------------------- + runner = build_inference_runner( + task=Task.DETECT, + model=detector, + device="cpu", + snapshot_path=None, # external detectors manage their own weights + batch_size=1, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) + + # ------------------------- + # 4. Create mock input data + # ------------------------- + # Single RGB image (H, W, C) + image = np.zeros((128, 256, 3), dtype=np.uint8) + + # ------------------------- + # 5. Run inference + # ------------------------- + results = runner.inference([image]) + + # ------------------------- + # 6. Check outputs + # ------------------------- + assert isinstance(results, list) + assert len(results) == 1 + + det = results[0] + assert isinstance(det, dict) + assert "bboxes" in det + assert "bbox_scores" in det + + bboxes = det["bboxes"] + scores = det["bbox_scores"] + + assert isinstance(bboxes, np.ndarray) + assert isinstance(scores, np.ndarray) + + assert bboxes.shape == (1, 4) + assert scores.shape == (1,) + + # Check bbox sanity (MockExternalDetector returns centered box) + x1, y1, x2, y2 = bboxes[0] + assert x2 > x1 + assert y2 > y1 + + # Score sanity + assert np.isclose(scores[0], 0.9) diff --git a/tests/pose_estimation_pytorch/models/external_detectors/test_inference_wrapper.py b/tests/pose_estimation_pytorch/models/external_detectors/test_inference_wrapper.py new file mode 100644 index 000000000..26de710f0 --- /dev/null +++ b/tests/pose_estimation_pytorch/models/external_detectors/test_inference_wrapper.py @@ -0,0 +1,393 @@ +from __future__ import annotations + +from pathlib import Path + +import albumentations as A +import numpy as np +import pytest +import torch.nn as nn + +from deeplabcut.pose_estimation_pytorch.data.preprocessor import build_top_down_preprocessor +from deeplabcut.pose_estimation_pytorch.runners.inference import ( + DetectorToPoseInferenceRunner, + build_inference_runner, +) +from deeplabcut.pose_estimation_pytorch.task import Task + + +class DummyDetectorRunner: + """Simple detector runner stub returning predefined outputs.""" + + def __init__(self, outputs): + self.outputs = outputs + self.calls = [] + + def inference(self, images, shelf_writer=None): + images = list(images) + self.calls.append( + { + "images": images, + "shelf_writer": shelf_writer, + } + ) + return self.outputs + + +class RecordingPoseRunner: + """ + Minimal pose runner stub that records what it receives and returns a fixed result. + """ + + def __init__(self, return_value=None): + self.calls = [] + self.return_value = return_value if return_value is not None else [{"ok": True}] + + def inference(self, images, shelf_writer=None): + images = list(images) + self.calls.append( + { + "images": images, + "shelf_writer": shelf_writer, + } + ) + return self.return_value + + +class PreprocessingPoseRunner: + """ + Small integration-style pose runner that actually runs the real top-down preprocessor. + + This lets us verify that the wrapper injects context["bboxes"] in the exact form + expected by TopDownCrop. + """ + + def __init__(self, preprocessor): + self.preprocessor = preprocessor + self.calls = [] + + def inference(self, images, shelf_writer=None): + images = list(images) + self.calls.append( + { + "images": images, + "shelf_writer": shelf_writer, + } + ) + + outputs = [] + for item in images: + if isinstance(item, tuple): + image, context = item + else: + image, context = item, {} + + proc_image, proc_context = self.preprocessor(image, context) + + outputs.append( + { + "image_shape": tuple(proc_image.shape), + "num_bboxes": len(context["bboxes"]), + "offsets_shape": tuple(np.asarray(proc_context["offsets"]).shape), + "scales_shape": tuple(np.asarray(proc_context["scales"]).shape), + "top_down_crop_size": proc_context["top_down_crop_size"], + } + ) + + return outputs + + +def test_detector_then_pose_inference_injects_bboxes_and_preserves_context(): + detector_outputs = [ + { + "bboxes": np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32), + "bbox_scores": np.array([0.9], dtype=np.float32), + }, + { + "bboxes": np.array( + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + ], + dtype=np.float32, + ), + "bbox_scores": np.array([0.7, 0.8], dtype=np.float32), + }, + ] + + detector_runner = DummyDetectorRunner(detector_outputs) + pose_runner = RecordingPoseRunner(return_value=[{"poses": "ok"}]) + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + original_context0 = {"foo": "bar"} + original_context1 = {"answer": 42} + + images = [ + ("img0.png", original_context0), + (Path("img1.png"), original_context1), + ] + + results = runner.inference(images) + + assert results == [{"poses": "ok"}] + + # Detector gets raw image inputs only; incoming contexts are preserved + # and forwarded to the pose runner after bbox injection. + assert len(detector_runner.calls) == 1 + assert detector_runner.calls[0]["images"] == ["img0.png", Path("img1.png")] + + # Pose runner got enriched inputs + assert len(pose_runner.calls) == 1 + enriched = pose_runner.calls[0]["images"] + assert len(enriched) == 2 + + image0, context0 = enriched[0] + assert image0 == "img0.png" + assert context0["foo"] == "bar" + np.testing.assert_allclose( + context0["bboxes"], + np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32), + ) + np.testing.assert_allclose( + context0["bbox_scores"], + np.array([0.9], dtype=np.float32), + ) + assert context0["detector_output"] is detector_outputs[0] + + image1, context1 = enriched[1] + assert image1 == Path("img1.png") + assert context1["answer"] == 42 + np.testing.assert_allclose( + context1["bboxes"], + np.array( + [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + ], + dtype=np.float32, + ), + ) + np.testing.assert_allclose( + context1["bbox_scores"], + np.array([0.7, 0.8], dtype=np.float32), + ) + assert context1["detector_output"] is detector_outputs[1] + + # Original input contexts should remain untouched + assert original_context0 == {"foo": "bar"} + assert original_context1 == {"answer": 42} + + +def test_detector_then_pose_inference_defaults_bbox_scores_when_missing(): + detector_outputs = [ + { + "bboxes": np.array( + [ + [10.0, 20.0, 30.0, 40.0], + [50.0, 60.0, 70.0, 80.0], + ], + dtype=np.float32, + ) + } + ] + + detector_runner = DummyDetectorRunner(detector_outputs) + pose_runner = RecordingPoseRunner() + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + runner.inference(["img0.png"]) + + enriched = pose_runner.calls[0]["images"] + _, context = enriched[0] + + np.testing.assert_allclose( + context["bbox_scores"], + np.array([1.0, 1.0], dtype=np.float32), + ) + + +def test_detector_then_pose_inference_handles_no_detections(): + detector_outputs = [ + { + "bboxes": np.zeros((0, 4), dtype=np.float32), + } + ] + + detector_runner = DummyDetectorRunner(detector_outputs) + pose_runner = RecordingPoseRunner() + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + runner.inference(["img0.png"]) + + enriched = pose_runner.calls[0]["images"] + _, context = enriched[0] + + assert isinstance(context["bboxes"], np.ndarray) + assert isinstance(context["bbox_scores"], np.ndarray) + assert context["bboxes"].shape == (0, 4) + assert context["bbox_scores"].shape == (0,) + + +def test_detector_then_pose_inference_raises_on_output_count_mismatch(): + detector_runner = DummyDetectorRunner( + [ + { + "bboxes": np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), + "bbox_scores": np.array([0.9], dtype=np.float32), + } + ] + ) + pose_runner = RecordingPoseRunner() + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + with pytest.raises(ValueError, match="Detector returned 1 outputs for 2 input images"): + runner.inference(["img0.png", "img1.png"]) + + # Pose runner should not be called if detector output count is invalid + assert len(pose_runner.calls) == 0 + + +def test_detector_then_pose_inference_raises_on_invalid_bbox_score_length(): + detector_outputs = [ + { + "bboxes": np.array( + [ + [10.0, 20.0, 30.0, 40.0], + [50.0, 60.0, 70.0, 80.0], + ], + dtype=np.float32, + ), + "bbox_scores": np.array([0.5], dtype=np.float32), # wrong length + } + ] + + detector_runner = DummyDetectorRunner(detector_outputs) + pose_runner = RecordingPoseRunner() + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + with pytest.raises(ValueError, match="Expected one bbox score per bbox"): + runner.inference(["img0.png"]) + + assert len(pose_runner.calls) == 0 + + +def test_detector_then_pose_inference_passes_shelf_writer_through(): + detector_outputs = [ + { + "bboxes": np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32), + "bbox_scores": np.array([0.9], dtype=np.float32), + } + ] + + detector_runner = DummyDetectorRunner(detector_outputs) + pose_runner = RecordingPoseRunner() + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + shelf_writer = object() + runner.inference(["img0.png"], shelf_writer=shelf_writer) + + assert detector_runner.calls[0]["shelf_writer"] is None + assert pose_runner.calls[0]["shelf_writer"] is shelf_writer + + +def test_detector_then_pose_integration_with_real_top_down_preprocessor(): + """ + Integration-style test: + prove that wrapper-injected context["bboxes"] is consumed by the real top-down + preprocessor and produces a crop batch of shape [num_individuals, 3, H, W]. + """ + preprocessor = build_top_down_preprocessor( + color_mode="RGB", + transform=A.Compose( + [], + bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]), + ), + top_down_crop_size=(32, 24), # width, height + top_down_crop_margin=0, + top_down_crop_with_context=True, + ) + + pose_runner = PreprocessingPoseRunner(preprocessor=preprocessor) + + detector_runner = DummyDetectorRunner( + [ + { + "bboxes": np.array( + [ + [10.0, 10.0, 20.0, 20.0], + [40.0, 15.0, 30.0, 25.0], + ], + dtype=np.float32, + ), + "bbox_scores": np.array([0.8, 0.9], dtype=np.float32), + } + ] + ) + + runner = DetectorToPoseInferenceRunner( + pose_runner=pose_runner, + detector_runner=detector_runner, + ) + + image = np.zeros((100, 120, 3), dtype=np.uint8) + + outputs = runner.inference([image]) + + assert len(outputs) == 1 + out = outputs[0] + + # ToTensor converts NHWC -> NCHW + assert out["image_shape"] == (2, 3, 24, 32) + assert out["num_bboxes"] == 2 + + # Offsets/scales are produced per crop + assert out["offsets_shape"] == (2, 2) + assert out["scales_shape"] == (2, 2) + + # TopDownCrop stores output_size as (width, height) + assert out["top_down_crop_size"] == (32, 24) + + +class TinyModel(nn.Module): + def forward(self, x, **kwargs): + return x + + +def test_build_inference_runner_wraps_top_down_runner_when_detector_runner_is_given(): + model = TinyModel() + detector_runner = DummyDetectorRunner(outputs=[]) + + runner = build_inference_runner( + task=Task.TOP_DOWN, + model=model, + device="cpu", + snapshot_path=None, + batch_size=1, + preprocessor=None, + postprocessor=None, + detector_runner=detector_runner, + ) + + assert isinstance(runner, DetectorToPoseInferenceRunner) diff --git a/tests/pose_estimation_pytorch/models/external_detectors/test_precomputed_bbox.py b/tests/pose_estimation_pytorch/models/external_detectors/test_precomputed_bbox.py new file mode 100644 index 000000000..2c1306f5b --- /dev/null +++ b/tests/pose_estimation_pytorch/models/external_detectors/test_precomputed_bbox.py @@ -0,0 +1,463 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Tests bbox schema + precomputed detector runner integration with DLC code.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +import deeplabcut.pose_estimation_pytorch.data.base as base_mod +from deeplabcut.pose_estimation_pytorch.data.base import Loader +from deeplabcut.pose_estimation_pytorch.data.bboxes import ( + BBoxEntry, + BBoxes, +) +from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters +from deeplabcut.pose_estimation_pytorch.data.postprocessor import build_detector_postprocessor +from deeplabcut.pose_estimation_pytorch.data.preprocessor import build_bottom_up_preprocessor +from deeplabcut.pose_estimation_pytorch.data.transforms import build_transforms +from deeplabcut.pose_estimation_pytorch.models.detectors.external import EXTERNAL_DETECTORS, PrecomputedDetectorRunner + +# Important: ensure the mock detector module is imported so registry population happens +from deeplabcut.pose_estimation_pytorch.runners.inference import build_inference_runner +from deeplabcut.pose_estimation_pytorch.task import Task + + +class DummyPoseDataset: + """ + Tiny stand-in for PoseDataset so tests can inspect what create_dataset() + actually passes through without depending on the real dataset internals. + """ + + def __init__( + self, + images, + annotations, + transform, + mode, + task, + parameters, + ctd_config=None, + ): + self.images = images + self.annotations = annotations + self.transform = transform + self.mode = mode + self.task = task + self.parameters = parameters + self.ctd_config = ctd_config + + +class FakeDLCLoader(Loader): + """ + Minimal loader for testing create_dataset() logic. + + It mimics DLCLoader’s backward-compatible behavior: + top-down and detect tasks default to keypoint-derived boxes unless a + detector_runner is provided. + """ + + def __init__(self, bbox_source: str | None = None): + # Avoid calling Loader.__init__() because we want a tiny controlled fixture + self.project_root = Path(".") + self.image_root = Path(".") + self.model_config_path = Path("dummy_pytorch_config.yaml") + self.model_cfg = { + "method": "td", + "data": { + "bbox_margin": 7, + }, + "train_settings": {}, + } + if bbox_source is not None: + self.model_cfg["data"]["bbox_source"] = bbox_source + + self.pose_task = Task.TOP_DOWN + self._loaded_data = {} + + # Cached payload, reused across calls. + # This lets us test that create_dataset() does NOT mutate cached load_data(). + self._payload = { + "images": [ + { + "id": 1, + "file_name": "img0.png", + "width": 256, + "height": 128, + } + ], + "annotations": [ + { + "id": 1, + "image_id": 1, + "category_id": 1, + "individual": "animal", + "individual_id": 0, + # Placeholder bbox that should be overridden + "bbox": np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + "area": 12.0, + "keypoints": np.array( + [ + [30.0, 40.0, 2.0], + [50.0, 60.0, 2.0], + ], + dtype=np.float32, + ), + "num_keypoints": 2, + "iscrowd": 0, + } + ], + } + + def load_data(self, mode: str = "train"): + self._loaded_data.setdefault(mode, self._payload) + return self._loaded_data[mode] + + def get_dataset_parameters(self) -> PoseDatasetParameters: + return PoseDatasetParameters( + bodyparts=["nose", "tail"], + unique_bpts=[], + individuals=["animal"], + with_center_keypoints=False, + color_mode="RGB", + top_down_crop_size=(256, 256), + top_down_crop_margin=0, + top_down_crop_with_context=True, + ) + + def default_bbox_method(self, task: Task) -> str | None: + # Mimic DLCLoader backward compatibility + if task in (Task.TOP_DOWN, Task.DETECT): + return "keypoints" + return None + + +@pytest.fixture(autouse=True) +def patch_pose_dataset(monkeypatch): + """ + Replace PoseDataset with a tiny dummy object so tests focus on loader logic. + """ + monkeypatch.setattr(base_mod, "PoseDataset", DummyPoseDataset) + + +def test_bbox_entry_from_detector_context_roundtrip_xywh(): + """ + Schema should faithfully round-trip DLC detector context in xywh format. + """ + context = { + "bboxes": np.array([[10.0, 20.0, 30.0, 40.0]], dtype=np.float32), + "bbox_scores": np.array([0.9], dtype=np.float32), + } + + entry = BBoxEntry.from_detector_context( + context, + image_path=Path("img0.png"), + bbox_format="xywh", + ) + + assert entry.image_path == Path("img0.png") + assert entry.bbox_format == "xywh" + assert np.allclose(entry.bbox_scores, [0.9]) + + restored = entry.to_detector_context(target_format="xywh") + np.testing.assert_allclose(restored["bboxes"], context["bboxes"]) + np.testing.assert_allclose(restored["bbox_scores"], context["bbox_scores"]) + + +def test_precomputed_detector_runner_inference_matches_dlc_contract(): + """ + PrecomputedDetectorRunner should behave like a DLC detector runner: + inference(images) -> list[{"bboxes": ..., "bbox_scores": ...}] + """ + bboxes = BBoxes( + train=[ + BBoxEntry( + bboxes=[(1.0, 2.0, 3.0, 4.0)], + bbox_scores=[0.8], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ] + ) + + runner = PrecomputedDetectorRunner.from_bboxes( + bboxes, + mode="train", + target_format="xywh", + validate_image_paths=True, + ) + + outputs = runner.inference([Path("img0.png")]) + + assert isinstance(outputs, list) + assert len(outputs) == 1 + assert "bboxes" in outputs[0] + assert "bbox_scores" in outputs[0] + + np.testing.assert_allclose( + outputs[0]["bboxes"], + np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32), + ) + np.testing.assert_allclose( + outputs[0]["bbox_scores"], + np.array([0.8], dtype=np.float32), + ) + + +def test_create_dataset_accepts_precomputed_detector_runner(): + """ + DLC loader.create_dataset(...) should be able to consume PrecomputedDetectorRunner + and rewrite annotation bboxes accordingly. + """ + loader = FakeDLCLoader() + + precomputed = BBoxes( + train=[ + BBoxEntry( + bboxes=[(11.0, 12.0, 13.0, 14.0)], + bbox_scores=[0.95], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ] + ) + + detector_runner = PrecomputedDetectorRunner.from_bboxes( + precomputed, + mode="train", + target_format="xywh", + validate_image_paths=True, + ) + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + ann = dataset.annotations[0] + actual_bbox = np.asarray(ann["bbox"], dtype=np.float32) + + np.testing.assert_allclose( + actual_bbox, + np.array([11.0, 12.0, 13.0, 14.0], dtype=np.float32), + ) + + +def test_create_dataset_with_precomputed_detector_runner_does_not_mutate_cached_load_data(): + """ + Regression test: create_dataset() must deep-copy cached annotations before rewriting + bboxes, otherwise load_data() becomes stateful and unsafe. + """ + loader = FakeDLCLoader() + + precomputed = BBoxes( + train=[ + BBoxEntry( + bboxes=[(21.0, 22.0, 23.0, 24.0)], + bbox_scores=[0.88], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ] + ) + + detector_runner = PrecomputedDetectorRunner.from_bboxes( + precomputed, + mode="train", + target_format="xywh", + validate_image_paths=True, + ) + + raw_before = np.asarray(loader.load_data("train")["annotations"][0]["bbox"], dtype=np.float32).copy() + np.testing.assert_allclose(raw_before, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=detector_runner, + ) + + np.testing.assert_allclose( + np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32), + np.array([21.0, 22.0, 23.0, 24.0], dtype=np.float32), + ) + + # Cached raw annotations must remain untouched + raw_after = np.asarray(loader.load_data("train")["annotations"][0]["bbox"], dtype=np.float32) + np.testing.assert_allclose(raw_after, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + + +def test_live_mock_detector_can_roundtrip_through_schema_and_precomputed_runner(): + """ + Strong integration test: + + live mock external detector + -> DLC DetectorInferenceRunner + -> detector context + -> BBoxEntry.from_detector_context(...) + -> PrecomputedDetectorRunner + -> loader.create_dataset(..., detector_runner=...) + + This proves the schema/adapter layer can bridge live detector outputs + back into DLC’s training/data path. + """ + # 1. Build live mock external detector + detector = EXTERNAL_DETECTORS.build( + { + "type": "MockExternalDetector", + "score": 0.9, + "label": 1, + } + ) + detector.eval() + + # 2. Build DLC detector inference runner around it + transform = build_transforms({"scale_to_unit_range": True}) + + runner = build_inference_runner( + task=Task.DETECT, + model=detector, + device="cpu", + snapshot_path=None, + batch_size=1, + preprocessor=build_bottom_up_preprocessor( + color_mode="RGB", + transform=transform, + ), + postprocessor=build_detector_postprocessor( + max_individuals=1, + min_bbox_score=0.0, + ), + ) + + # 3. Run detector on a mock image + image = np.zeros((128, 256, 3), dtype=np.uint8) + live_outputs = runner.inference([image]) + + assert len(live_outputs) == 1 + live_context = live_outputs[0] + + assert "bboxes" in live_context + assert "bbox_scores" in live_context + + # 4. Convert live DLC detector output -> schema + entry = BBoxEntry.from_detector_context( + live_context, + image_path=Path("img0.png"), + bbox_format="xywh", # DLC postprocessed outputs are expected here + ) + + # 5. Build precomputed runner from that schema + precomputed = BBoxes(train=[entry]) + + precomputed_runner = PrecomputedDetectorRunner.from_bboxes( + precomputed, + mode="train", + target_format="xywh", + validate_image_paths=True, + ) + + # 6. Use in DLC create_dataset(...) + loader = FakeDLCLoader() + dataset = loader.create_dataset( + transform=None, + mode="train", + task=Task.TOP_DOWN, + detector_runner=precomputed_runner, + ) + + actual_bbox = np.asarray(dataset.annotations[0]["bbox"], dtype=np.float32) + expected_bbox = np.asarray(live_context["bboxes"][0], dtype=np.float32) + + np.testing.assert_allclose(actual_bbox, expected_bbox) + + +def test_precomputed_detector_runner_supports_path_based_subset_lookup(): + bboxes = BBoxes( + train=[ + BBoxEntry( + bboxes=[(1.0, 2.0, 3.0, 4.0)], + bbox_scores=[0.1], + bbox_format="xywh", + image_path=Path("img0.png"), + ), + BBoxEntry( + bboxes=[(5.0, 6.0, 7.0, 8.0)], + bbox_scores=[0.9], + bbox_format="xywh", + image_path=Path("img1.png"), + ), + ] + ) + + runner = PrecomputedDetectorRunner.from_bboxes(bboxes, mode="train") + + outputs = runner.inference([Path("img1.png")]) + + assert len(outputs) == 1 + np.testing.assert_allclose( + outputs[0]["bboxes"], + np.array([[5.0, 6.0, 7.0, 8.0]], dtype=np.float32), + ) + np.testing.assert_allclose( + outputs[0]["bbox_scores"], + np.array([0.9], dtype=np.float32), + ) + + +def test_precomputed_detector_runner_preserves_requested_path_order(): + bboxes = BBoxes( + train=[ + BBoxEntry( + bboxes=[(1.0, 2.0, 3.0, 4.0)], + bbox_scores=[0.1], + bbox_format="xywh", + image_path=Path("img0.png"), + ), + BBoxEntry( + bboxes=[(5.0, 6.0, 7.0, 8.0)], + bbox_scores=[0.9], + bbox_format="xywh", + image_path=Path("img1.png"), + ), + ] + ) + + runner = PrecomputedDetectorRunner.from_bboxes(bboxes, mode="train") + + outputs = runner.inference([Path("img1.png"), Path("img0.png")]) + + np.testing.assert_allclose(outputs[0]["bboxes"], np.array([[5.0, 6.0, 7.0, 8.0]], dtype=np.float32)) + np.testing.assert_allclose(outputs[1]["bboxes"], np.array([[1.0, 2.0, 3.0, 4.0]], dtype=np.float32)) + + +def test_precomputed_detector_runner_raises_for_unknown_requested_path(): + bboxes = BBoxes( + train=[ + BBoxEntry( + bboxes=[(1.0, 2.0, 3.0, 4.0)], + bbox_scores=[0.1], + bbox_format="xywh", + image_path=Path("img0.png"), + ) + ] + ) + + runner = PrecomputedDetectorRunner.from_bboxes(bboxes, mode="train") + + with pytest.raises(ValueError, match="No precomputed bbox entry found"): + runner.inference([Path("missing.png")])