diff --git a/.circleci/config.yml b/.circleci/config.yml
index 00da359495..a8cb967fab 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -7,7 +7,7 @@ jobs:
build-and-test:
working_directory: ~/circleci-demo-python-django
docker:
- - image: circleci/python:3.8 # primary container for the build job
+ - image: circleci/python:3.10 # primary container for the build job
auth:
username: mydockerhub-user
password: $DOCKERHUB_PASSWORD # context / project UI env-var reference
diff --git a/.github/workflows/publish-book.yml b/.github/workflows/publish-book.yml
index 426e87370e..6aff2725d2 100644
--- a/.github/workflows/publish-book.yml
+++ b/.github/workflows/publish-book.yml
@@ -11,15 +11,15 @@ jobs:
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.9
+ - name: Set up Python 3.10
uses: actions/setup-python@v4
with:
- python-version: 3.9
+ python-version: 3.10
- name: Install dependencies
run: |
python -m pip install --upgrade pip
- python -m pip install .[tf,docs]
+ python -m pip install .[docs]
pip install jupyter-book sphinxcontrib-mermaid
- name: Build the book
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 04a4580456..c432acde6f 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -64,3 +64,5 @@ jobs:
pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }}
python examples/testscript.py
python examples/testscript_multianimal.py
+ python examples/testscript_pytorch_single_animal.py
+ python examples/testscript_pytorch_multi_animal.py
diff --git a/.gitignore b/.gitignore
index 86a6943487..ad9f192703 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,6 +10,7 @@ _build/*
/examples/3D*
/examples/m3*
/examples/OUT
+/examples/pretrained*
.local
.DS_Store
examples/.DS_Store
@@ -18,6 +19,15 @@ examples/.DS_Store
*.ckpt
snapshot-*
+# Modelzoo checkpoints
+deeplabcut/modelzoo/checkpoints/
+
+# PyTorch backbone weights
+deeplabcut/pose_estimation_pytorch/models/backbones/pretrained_weights/
+
+# Wandb files
+wandb/
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
@@ -115,7 +125,10 @@ ENV/
# Spyder project settings
.spyderproject
.spyproject
+
+# IDEs configurations
.vscode/*
+.idea/*
# Rope project settings
.ropeproject
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..a21cfeebbd
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,18 @@
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: check-added-large-files
+ - id: check-yaml
+ - id: end-of-file-fixer
+ - id: name-tests-test
+ - id: trailing-whitespace
+ - repo: https://github.com/PyCQA/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ language_version: python3
diff --git a/NOTICE.yml b/NOTICE.yml
index 68ff6362df..ec8df871b2 100644
--- a/NOTICE.yml
+++ b/NOTICE.yml
@@ -5,7 +5,7 @@
https://github.com/DeepLabCut/DeepLabCut
Please see AUTHORS for contributors.
- https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+ https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
Licensed under GNU Lesser General Public License v3.0
include:
@@ -17,6 +17,7 @@
# License for files adapted from DeeperCut by Eldar Insafutdinov
# https://github.com/eldar/pose-tensorflow
+
# Applies to most files in deeplabcut.pose_estimation_tensorflow
- header: |
DeepLabCut Toolbox (deeplabcut.org)
@@ -107,3 +108,8 @@
include:
- deeplabcut/pose_tracking_pytorch/solver/scheduler_factory.py
- deeplabcut/pose_tracking_pytorch/model/backones/vit_pytorch.py
+
+# PyTorch license
+
+- header: |
+ See https://github.com/pytorch/pytorch/blob/main/LICENSE
diff --git a/README.md b/README.md
index eea3ad6206..3d8319965a 100644
--- a/README.md
+++ b/README.md
@@ -60,15 +60,24 @@
Please click the link above for all the information you need to get started! Please note that currently we support only Python 3.10+ (see conda files for guidance).
-Developers Stable Release:
-- Very quick start: You need to have TensorFlow installed (up to v2.10 supported across platforms) `pip install "deeplabcut[gui,tf]"` that includes all functions plus GUIs, or `pip install deeplabcut[tf]` (headless version with PyTorch and TensorFlow).
-
-Developers Alpha Release:
-- We also have an alpha release of PyTorch DeepLabCut available! [Please see here for instructions and information](https://github.com/DeepLabCut/DeepLabCut/blob/pytorch_docs/docs/pytorch/user_guide.md).
+Developers Stable Release: very quick start (Python 3.10+ required) to install
+DeepLabCut with the PyTorch engine
+
+- [Install PyTorch](https://pytorch.org/get-started/locally/) (**select the desired
+CUDA version if you want to use a GPU**): `pip install torch torchvision`
+- Then, [install `pytables`](https://www.pytables.org/usersguide/installation.html): `conda install -c conda-forge pytables==3.8.0`
+- Finally, install `DeepLabCut` (with all functions + the GUI):
+`pip install --pre "deeplabcut[gui]"` or `pip install --pre "deeplabcut"` (headless
+version with PyTorch)!
+
+To use the TensorFlow engine (requires Python 3.10; TF up to v2.10 supported on Windows,
+up to v2.12 on other platforms): you'll need to run `pip install "deeplabcut[gui,tf]"`
+(which includes all functions plus GUIs) or `pip install "deeplabcut[tf]"` (headless
+version with PyTorch and TensorFlow).
We recommend using our conda file, see [here](https://github.com/DeepLabCut/DeepLabCut/blob/main/conda-environments/README.md) or the new [`deeplabcut-docker` package](https://github.com/DeepLabCut/DeepLabCut/tree/main/docker).
-# [Documentation: The DeepLabCut Process](https://deeplabcut.github.io/DeepLabCut)
+# [Documentation: The DeepLabCut Process](https://deeplabcut.github.io/DeepLabCut/README.html)
Our docs walk you through using DeepLabCut, and key API points. For an overview of the toolbox and workflow for project management, see our step-by-step at [Nature Protocols paper](https://doi.org/10.1038/s41596-019-0176-0).
@@ -82,9 +91,7 @@ For a deeper understanding and more resources for you to get started with Python
🐭 pose tracking of single animals demo [](https://colab.research.google.com/github/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb)
-🐭🐭🐭 pose tracking of multiple animals demo [](https://colab.research.google.com/github/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_3miceDemo.ipynb)
-
-- See [more demos here](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/README.md). We provide data and several Jupyter Notebooks: one that walks you through a demo dataset to test your installation, and another Notebook to run DeepLabCut from the beginning on your own data. We also show you how to use the code in Docker, and on Google Colab.
+See [more demos here](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/README.md). We provide data and several Jupyter Notebooks: one that walks you through a demo dataset to test your installation, and another Notebook to run DeepLabCut from the beginning on your own data. We also show you how to use the code in Docker, and on Google Colab.
# Why use DeepLabCut?
diff --git a/_config.yml b/_config.yml
index c55ef8cebb..a24fcaf718 100644
--- a/_config.yml
+++ b/_config.yml
@@ -5,7 +5,7 @@ only_build_toc_files: true
sphinx:
config:
- autodoc_mock_imports: ["wx"]
+ autodoc_mock_imports: ["wx", "matplotlib", "qtpy", "PySide6", "napari", "shiboken6"]
mermaid_output_format: raw
extra_extensions:
- numpydoc
diff --git a/_toc.yml b/_toc.yml
index 471655454f..8935c8ff8e 100644
--- a/_toc.yml
+++ b/_toc.yml
@@ -46,10 +46,8 @@ parts:
chapters:
- file: docs/ModelZoo
- file: docs/recipes/UsingModelZooPupil
- - file: docs/recipes/MegaDetectorDLCLive
- caption: 🧑🍳 Cookbook (detailed helper guides)
chapters:
- - file: docs/tutorial
- file: docs/convert_maDLC
- file: docs/recipes/OtherData
- file: docs/recipes/io
diff --git a/conda-environments/DEEPLABCUT.yaml b/conda-environments/DEEPLABCUT.yaml
index 66edc47212..5b5c77d4c6 100644
--- a/conda-environments/DEEPLABCUT.yaml
+++ b/conda-environments/DEEPLABCUT.yaml
@@ -9,9 +9,13 @@
#Licensed under GNU Lesser General Public License v3.0
#
# DeepLabCut environment
-# FIRST: INSTALL CORRECT DRIVER for GPU, see https://stackoverflow.com/questions/30820513/what-is-the-correct-version-of-cuda-for-my-nvidia-driver/30820690
#
-# AFTER THIS FILE IS INSTALLED, if you have a GPU be sure to install cudnn from conda-forge: conda install cudnn -c conda-forge
+# FIRST: If you have an NVIDIA GPU and want to use it, check that you have drivers installed!
+# To check if your GPUs are visible to PyTorch (and thus DeepLabCut), run:
+# >>> python -c "import torch; print(torch.cuda.is_available())"
+#
+# If "False" is printed, PyTorch (and thus DeepLabCut) cannot access your GPU. For
+# more information, see: https://pytorch.org/get-started/locally/
#
# install: conda env create -f DEEPLABCUT.yaml
# update: conda env update -f DEEPLABCUT.yaml
@@ -29,4 +33,6 @@ dependencies:
- ffmpeg
- pytables==3.8.0
- pip:
+ - torch
+ - torchvision
- "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut[gui,modelzoo,wandb]"
diff --git a/conda-environments/DEEPLABCUT_M1.yaml b/conda-environments/DEEPLABCUT_M1.yaml
deleted file mode 100644
index 47e8c02572..0000000000
--- a/conda-environments/DEEPLABCUT_M1.yaml
+++ /dev/null
@@ -1,43 +0,0 @@
-# DEEPLABCUT_M1.yaml
-
-#DeepLabCut2.0 Toolbox (deeplabcut.org)
-#© A. & M.W. Mathis Labs
-#https://github.com/DeepLabCut/DeepLabCut
-#Please see AUTHORS for contributors.
-
-#https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
-#Licensed under GNU Lesser General Public License v3.0
-#
-# DeepLabCut M1/M2 (Apple Silicon) environment instructions
-#
-# We'll get the miniconda M1 bash installer, as explained in
-# https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html
-#
-# In the Terminal, run the following commands:
-# wget https://repo.anaconda.com/miniconda/Miniconda3-py39_4.12.0-MacOSX-arm64.sh -O ~/miniconda.sh
-# bash ~/miniconda.sh -b -p $HOME/miniconda
-# source ~/miniconda/bin/activate
-# conda init zsh
-#
-# Then, `git clone DeepLabCut`, and run:
-#
-# conda env create -f conda-environments/DEEPLABCUT_M1.yaml
-#
-# Next, activate the environment, and launch DLC with pythonw -m deeplabcut
-
-name: DEEPLABCUT_M1
-channels:
- - conda-forge
- - defaults
-dependencies:
- - python=3.10
- - pip
- - ipython
- - jupyter
- - nb_conda
- - notebook<7.0.0
- - python.app
- - ffmpeg
- - apple::tensorflow-deps
- - pip:
- - "deeplabcut[gui,apple_mchips]"
diff --git a/deeplabcut/__init__.py b/deeplabcut/__init__.py
index 2da4b6a9f5..72dac1e3ce 100644
--- a/deeplabcut/__init__.py
+++ b/deeplabcut/__init__.py
@@ -12,10 +12,6 @@
import os
-# Suppress tensorflow warning messages
-import tensorflow as tf
-
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
DEBUG = True and "DEBUG" in os.environ and os.environ["DEBUG"]
from deeplabcut.version import __version__, VERSION
@@ -34,6 +30,7 @@
"DLC loaded in light mode; you cannot use any GUI (labeling, relabeling and standalone GUI)"
)
+from deeplabcut.core.engine import Engine
from deeplabcut.create_project import (
create_new_project,
create_new_project_3d,
@@ -49,6 +46,7 @@
mergeandsplit,
)
from deeplabcut.generate_training_dataset import (
+ create_training_dataset_from_existing_split,
create_training_model_comparison,
create_multianimaltraining_dataset,
)
@@ -60,6 +58,9 @@
dropduplicatesinannotatinfiles,
dropunlabeledframes,
)
+
+from deeplabcut.modelzoo.video_inference import video_inference_superanimal
+
from deeplabcut.utils import (
create_labeled_video,
create_video_with_all_detections,
@@ -92,13 +93,14 @@
)
# Train, evaluate & predict functions / all require TF
-from deeplabcut.pose_estimation_tensorflow import (
+from deeplabcut.compat import (
train_network,
return_train_network_path,
evaluate_network,
return_evaluate_network_data,
analyze_videos,
create_tracking_dataset,
+ analyze_images,
analyze_time_lapse_frames,
convert_detections2tracklets,
extract_maps,
@@ -107,7 +109,6 @@
visualize_paf,
extract_save_all_maps,
export_model,
- video_inference_superanimal,
)
diff --git a/deeplabcut/__main__.py b/deeplabcut/__main__.py
index 93b3f44b64..8d8c782a74 100644
--- a/deeplabcut/__main__.py
+++ b/deeplabcut/__main__.py
@@ -9,20 +9,25 @@
# Licensed under GNU Lesser General Public License v3.0
#
-try:
- import PySide6
+def main():
+ try:
+ import PySide6
- lite = False
-except ModuleNotFoundError:
- lite = True
+ lite = False
+ except ModuleNotFoundError:
+ lite = True
-# if module is executed directly (i.e. `python -m deeplabcut.__init__`) launch straight into the GUI
-if not lite:
- print("Starting GUI...")
- from deeplabcut.gui.launch_script import launch_dlc
+ # if module is executed directly (i.e. `python -m deeplabcut.__init__`) launch straight into the GUI
+ if not lite:
+ print("Starting GUI...")
+ from deeplabcut.gui.launch_script import launch_dlc
- launch_dlc()
-else:
- print(
- "You installed DLC lite, thus GUI's cannot be used. If you need GUI support please: pip install 'deeplabcut[gui]''"
- )
+ launch_dlc()
+ else:
+ print(
+ "You installed DLC lite, thus GUI's cannot be used. If you need GUI support please: pip install 'deeplabcut[gui]''"
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/deeplabcut/benchmark/__init__.py b/deeplabcut/benchmark/__init__.py
index 2e70eae786..a5ddd3c4a9 100644
--- a/deeplabcut/benchmark/__init__.py
+++ b/deeplabcut/benchmark/__init__.py
@@ -85,7 +85,11 @@ def evaluate(
continue
benchmark = benchmark_cls()
for name in benchmark.names():
- if Result(method_name=name, benchmark_name=benchmark_cls.name) in results:
+ if Result(
+ code=benchmark.code,
+ method_name=name,
+ benchmark_name=benchmark_cls.name,
+ ) in results:
continue
else:
result = benchmark.evaluate(name, on_error=on_error)
diff --git a/deeplabcut/benchmark/base.py b/deeplabcut/benchmark/base.py
index 49140c4436..ef6894ac3b 100644
--- a/deeplabcut/benchmark/base.py
+++ b/deeplabcut/benchmark/base.py
@@ -59,7 +59,7 @@ def get_predictions(self):
raise NotImplementedError()
def __init__(self):
- keys = ["name", "keypoints", "ground_truth", "metadata"]
+ keys = ["code", "name", "keypoints", "ground_truth", "metadata"]
for key in keys:
if not hasattr(self, key):
raise NotImplementedError(
@@ -110,6 +110,7 @@ def evaluate(self, name: str, on_error="raise"):
else:
raise NotImplementedError() from exception
return Result(
+ code=self.code,
method_name=name,
benchmark_name=self.name,
mean_avg_precision=mean_avg_precision,
@@ -139,6 +140,7 @@ def _validate_predictions(self, name: str, predictions: dict) -> dict:
class Result:
"""Benchmark result."""
+ code: str
method_name: str
benchmark_name: str
root_mean_squared_error: float = float("nan")
@@ -146,6 +148,7 @@ class Result:
benchmark_version: str = __version__
_export_mapping = dict(
+ code="code",
benchmark_name="benchmark",
method_name="method",
benchmark_version="version",
diff --git a/deeplabcut/benchmark/benchmarks.py b/deeplabcut/benchmark/benchmarks.py
index ee18e215c6..4068c29cf2 100644
--- a/deeplabcut/benchmark/benchmarks.py
+++ b/deeplabcut/benchmark/benchmarks.py
@@ -96,6 +96,16 @@ def compute_pose_map(self, results_objects):
symmetric_kpts=[(0, 4), (1, 3)],
)
+ def _validate_predictions(self, name: str, predictions: dict) -> dict:
+ """Fixes filenames for predictions made on old versions of the dataset"""
+ return super()._validate_predictions(
+ name,
+ {
+ k.replace("Dummy", "D").replace("Dead pup", "DP"): v
+ for k, v in predictions.items()
+ },
+ )
+
class MarmosetBenchmark(deeplabcut.benchmark.base.Benchmark):
"""Dataset with two marmosets.
diff --git a/deeplabcut/benchmark/metrics.py b/deeplabcut/benchmark/metrics.py
index ba735a9db7..e73eb4cba2 100644
--- a/deeplabcut/benchmark/metrics.py
+++ b/deeplabcut/benchmark/metrics.py
@@ -29,8 +29,7 @@
import pandas as pd
import deeplabcut.benchmark.utils
-from deeplabcut.pose_estimation_tensorflow.core import evaluate_multianimal
-from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils
+from deeplabcut.core import inferenceutils, crossvalutils
from deeplabcut.utils.conversioncode import guarantee_multiindex_rows
@@ -99,7 +98,7 @@ def calc_prediction_errors(preds, gt):
if visible.size and xy_pred_.size:
# Pick the predictions closest to ground truth,
# rather than the ones the model has most confident in.
- neighbors = evaluate_multianimal._find_closest_neighbors(
+ neighbors = crossvalutils.find_closest_neighbors(
xy_gt_[visible], xy_pred_, k=3
)
found = neighbors != -1
@@ -213,6 +212,7 @@ def calc_map_from_obj(
oks_sigma,
margin=margin,
symmetric_kpts=symmetric_kpts,
+ greedy_matching=True,
)
return oks["mAP"]
diff --git a/deeplabcut/benchmark/mot.py b/deeplabcut/benchmark/mot.py
new file mode 100644
index 0000000000..32c99b6987
--- /dev/null
+++ b/deeplabcut/benchmark/mot.py
@@ -0,0 +1,150 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+from __future__ import annotations
+
+import warnings
+
+import motmetrics as mm
+import numpy as np
+import pandas as pd
+from numpy.typing import NDArray
+
+from deeplabcut.core import trackingutils
+
+
+def _convert_bboxes_to_xywh(bboxes: NDArray, inplace: bool = False) -> NDArray:
+ w = bboxes[:, 2] - bboxes[:, 0]
+ h = bboxes[:, 3] - bboxes[:, 1]
+ if not inplace:
+ new_bboxes = bboxes.copy()
+ new_bboxes[:, 2] = w
+ new_bboxes[:, 3] = h
+ return new_bboxes
+ bboxes[:, 2] = w
+ bboxes[:, 3] = h
+
+
+def reconstruct_bboxes_from_bodyparts(
+ data: pd.DataFrame, margin: float, to_xywh: bool = False
+) -> NDArray:
+ x = data.xs("x", axis=1, level="coords")
+ y = data.xs("y", axis=1, level="coords")
+ p = data.xs("likelihood", axis=1, level="coords")
+ xy = np.stack([x, y], axis=2)
+ bboxes = np.full((data.shape[0], 5), np.nan)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ bboxes[:, :2] = np.nanmin(xy, axis=1) - margin
+ bboxes[:, 2:4] = np.nanmax(xy, axis=1) + margin
+ bboxes[:, 4] = np.nanmean(p, axis=1)
+ if to_xywh:
+ _convert_bboxes_to_xywh(bboxes, inplace=True)
+ return bboxes
+
+
+def reconstruct_all_bboxes(
+ data: pd.DataFrame, margin: float, to_xywh: bool = False
+) -> NDArray:
+ animals = data.columns.get_level_values("individuals").unique().tolist()
+ try:
+ animals.remove("single")
+ except ValueError:
+ pass
+ bboxes = np.full((len(animals), data.shape[0], 5), np.nan)
+ for n, animal in enumerate(animals):
+ bboxes[n] = reconstruct_bboxes_from_bodyparts(
+ data.xs(animal, axis=1, level="individuals"), margin, to_xywh
+ )
+ return bboxes
+
+
+def compute_mot_metrics(
+ h5_file_gt: str,
+ h5_file_pred: str,
+ tracker_type: str = "bbox",
+ **kwargs,
+) -> mm.MOTAccumulator:
+ df_gt = pd.read_hdf(h5_file_gt)
+ df = pd.read_hdf(h5_file_pred)
+ if tracker_type == "bbox":
+ func = reconstruct_all_bboxes
+ elif tracker_type == "ellipse":
+ func = trackingutils.reconstruct_all_ellipses
+ else:
+ raise ValueError(f"Unrecognized tracker type {tracker_type}.")
+
+ trackers_gt = func(df_gt, **kwargs)
+ trackers = func(df, **kwargs)
+ return _compute_mot_metrics(
+ trackers_gt, trackers, tracker_type,
+ )
+
+
+def _compute_mot_metrics(
+ trackers_ground_truth: NDArray,
+ trackers: NDArray,
+ tracker_type: str = "bbox",
+) -> mm.MOTAccumulator:
+ if trackers_ground_truth.shape != trackers.shape:
+ raise ValueError(
+ "Dimensions mismatch. There must be as many `trackers_ground_truth` as there are `trackers`."
+ )
+
+ if tracker_type == "bbox":
+ sl = slice(0, 4)
+ cost_func = mm.distances.iou_matrix
+ elif tracker_type == "ellipse":
+ sl = slice(0, 5)
+
+ def cost_func(ellipses_gt, ellipses_hyp):
+ cost_matrix = np.zeros((len(ellipses_gt), len(ellipses_hyp)))
+ gt_el = [trackingutils.Ellipse(*e[:5]) for e in ellipses_gt]
+ hyp_el = [trackingutils.Ellipse(*e[:5]) for e in ellipses_hyp]
+ for i, el in enumerate(gt_el):
+ for j, tracker in enumerate(hyp_el):
+ cost_matrix[i, j] = 1 - el.calc_similarity_with(tracker)
+ return cost_matrix
+
+ else:
+ raise ValueError(f"Unrecognized tracker type {tracker_type}.")
+
+ ids = np.arange(trackers_ground_truth.shape[0])
+ acc = mm.MOTAccumulator(auto_id=True)
+ for i in range(trackers_ground_truth.shape[1]):
+ trackers_gt = trackers_ground_truth[:, i, sl]
+ trackers_hyp = trackers[:, i, sl]
+ empty_gt = np.isnan(trackers_gt).any(axis=1)
+ empty_hyp = np.isnan(trackers_hyp).any(axis=1)
+ trackers_gt = trackers_gt[~empty_gt]
+ trackers_hyp = trackers_hyp[~empty_hyp]
+ cost = cost_func(trackers_gt, trackers_hyp)
+ acc.update(ids[~empty_gt], ids[~empty_hyp], cost)
+ return acc
+
+
+def print_all_metrics(
+ accumulators: list[mm.MOTAccumulator], all_params: list[str] | None = None
+):
+ if not all_params:
+ names = [f"iter{i + 1}" for i in range(len(accumulators))]
+ else:
+ s = "_".join("{}" for _ in range(len(all_params[0])))
+ names = [s.format(*params.values()) for params in all_params]
+ mh = mm.metrics.create()
+ summary = mh.compute_many(
+ accumulators, metrics=mm.metrics.motchallenge_metrics, names=names
+ )
+ strsummary = mm.io.render_summary(
+ summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names
+ )
+ print(strsummary)
+ return summary
diff --git a/deeplabcut/compat.py b/deeplabcut/compat.py
new file mode 100644
index 0000000000..b5973cfbcd
--- /dev/null
+++ b/deeplabcut/compat.py
@@ -0,0 +1,1968 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Compatibility file for methods available with either PyTorch or Tensorflow"""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Iterable
+
+import numpy as np
+from ruamel.yaml import YAML
+
+import deeplabcut.core.visualization as visualization
+from deeplabcut.core.engine import Engine
+from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine
+
+DEFAULT_ENGINE = Engine.PYTORCH
+
+
+def get_project_engine(cfg: dict) -> Engine:
+ """
+ Args:
+ cfg: the project configuration file
+
+ Returns:
+ the engine specified for the project, or the default engine if none is specified
+ """
+ if cfg.get("engine") is not None:
+ return Engine(cfg["engine"])
+
+ return DEFAULT_ENGINE
+
+
+def get_available_aug_methods(engine: Engine) -> tuple[str, ...]:
+ """
+ Args:
+ engine: the engine for which augmentation methods should be returned
+
+ Returns:
+ the augmentations available for the given engine, where the first one is the
+ default method to use
+
+ Raises:
+ RuntimeError: if no augmentations methods are defined for the given engine
+ """
+ if engine == Engine.TF:
+ return "imgaug", "default", "deterministic", "scalecrop", "tensorpack"
+ elif engine == Engine.PYTORCH:
+ return ("albumentations",)
+
+ raise RuntimeError(f"Unknown augmentation for engine: {engine}")
+
+
+def train_network(
+ config: str | Path,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ max_snapshots_to_keep: int | None = None,
+ displayiters: int | None = None,
+ saveiters: int | None = None,
+ maxiters: int | None = None,
+ epochs: int | None = None,
+ save_epochs: int | None = None,
+ allow_growth: bool = True,
+ gputouse: str | None = None,
+ autotune: bool = False,
+ keepdeconvweights: bool = True,
+ modelprefix: str = "",
+ superanimal_name: str = "",
+ superanimal_transfer_learning: bool = False,
+ engine: Engine | None = None,
+ device: str | None = None,
+ snapshot_path: str | Path | None = None,
+ detector_path: str | Path | None = None,
+ batch_size: int | None = None,
+ detector_batch_size: int | None = None,
+ detector_epochs: int | None = None,
+ detector_save_epochs: int | None = None,
+ pose_threshold: float | None = 0.1,
+ pytorch_cfg_updates: dict | None = None,
+):
+ """
+ Trains the network with the labels in the training dataset.
+
+ Parameters
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ shuffle: int, optional, default=1
+ Integer value specifying the shuffle index to select for training.
+
+ trainingsetindex: int, optional, default=0
+ Integer specifying which TrainingsetFraction to use.
+ Note that TrainingFraction is a list in config.yaml.
+
+ max_snapshots_to_keep: int or None
+ Sets how many snapshots are kept, i.e. states of the trained network. Every
+ saving iteration many times a snapshot is stored, however only the last
+ ``max_snapshots_to_keep`` many are kept! If you change this to None, then all
+ are kept.
+ See: https://github.com/DeepLabCut/DeepLabCut/issues/8#issuecomment-387404835
+
+ displayiters: optional, default=None
+ This variable is actually set in ``pose_config.yaml``. However, you can
+ overwrite it with this hack. Don't use this regularly, just if you are too lazy
+ to dig out the ``pose_config.yaml`` file for the corresponding project. If
+ ``None``, the value from there is used, otherwise it is overwritten!
+
+ saveiters: optional, default=None
+ Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``:
+ you can use ``save_epochs``).
+ This variable is actually set in ``pose_config.yaml``. However, you can
+ overwrite it with this hack. Don't use this regularly, just if you are too lazy
+ to dig out the ``pose_config.yaml`` file for the corresponding project.
+ If ``None``, the value from there is used, otherwise it is overwritten!
+
+ maxiters: optional, default=None
+ Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``:
+ you can use ``epochs``).
+ This variable is actually set in ``pose_config.yaml``. However, you can
+ overwrite it with this hack. Don't use this regularly, just if you are too lazy
+ to dig out the ``pose_config.yaml`` file for the corresponding project.
+ If ``None``, the value from there is used, otherwise it is overwritten!
+
+ epochs: optional, default=None
+ Only for the PyTorch engine (equivalent to the `maxiters` parameter for the
+ TensorFlow engine). The maximum number of epochs to train the model for. If
+ None, the value will be read from the `pytorch_config.yaml` file. An epoch is a
+ single pass through the training dataset, which means your model has seen each
+ training image exactly once. So if you have 64 training images for your network,
+ an epoch is 64 iterations with batch size 1 (or 32 iterations with batch size 2,
+ 16 with batch size 4, etc.).
+
+ save_epochs: optional, default=None
+ Only for the PyTorch engine (equivalent to the `saveiters` parameter for the
+ TensorFlow engine). The number of epochs between each snapshot save. If
+ None, the value will be read from the `pytorch_config.yaml` file.
+
+ allow_growth: bool, optional, default=True.
+ Only for the TensorFlow engine.
+ For some smaller GPUs the memory issues happen. If ``True``, the memory
+ allocator does not pre-allocate the entire specified GPU memory region, instead
+ starting small and growing as needed.
+ See issue: https://forum.image.sc/t/how-to-stop-running-out-of-vram/30551/2
+
+ gputouse: optional, default=None
+ Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``:
+ you can use ``device``).
+ Natural number indicating the number of your GPU (see number in nvidia-smi).
+ If you do not have a GPU put None.
+ See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ autotune: bool, optional, default=False
+ Only for the TensorFlow engine.
+ Property of TensorFlow, somehow faster if ``False``
+ (as Eldar found out, see https://github.com/tensorflow/tensorflow/issues/13317).
+
+ keepdeconvweights: bool, optional, default=True
+ Also restores the weights of the deconvolution layers (and the backbone) when
+ training from a snapshot. Note that if you change the number of bodyparts, you
+ need to set this to false for re-training.
+
+ modelprefix: str, optional, default=""
+ Directory containing the deeplabcut models to use when evaluating the network.
+ By default, the models are assumed to exist in the project folder.
+
+ superanimal_name: str, optional, default =""
+ Only for the TensorFlow engine. For the PyTorch engine, you need to specify
+ this through the ``weight_init`` when creating the training dataset.
+ Specified if transfer learning with superanimal is desired
+
+ superanimal_transfer_learning: bool, optional, default = False.
+ Only for the TensorFlow engine. For the PyTorch engine, you need to specify
+ this through the ``weight_init`` when creating the training dataset.
+ If set true, the training is transfer learning (new decoding layer). If set
+ false, and superanimal_name is True, then the training is fine-tuning (reusing
+ the decoding layer)
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ device: str, optional, default = None.
+ Only for the PyTorch engine. The device to run the training on (e.g. "cuda:0")
+
+ snapshot_path: str or Path, optional, default = None.
+ Only for the PyTorch engine. The path to the pose model snapshot to resume training from.
+
+ detector_path: str or Path, optional, default = None.
+ Only for the PyTorch engine. The path to the detector model snapshot to resume training from.
+
+ batch_size: int, optional, default = None.
+ Only for the PyTorch engine. The batch size to use while training.
+
+ detector_batch_size: int, optional, default = None.
+ Only for the PyTorch engine. The batch size to use while training the detector.
+
+ detector_epochs: int, optional, default = None.
+ Only for the PyTorch engine. The number of epochs to train the detector for.
+
+ detector_save_epochs: int, optional, default = None.
+ Only for the PyTorch engine. The number of epochs between each detector snapshot save.
+
+ pose_threshold: float, optional, default = 0.1.
+ Only for the PyTorch engine. Used for memory-replay. Pseudo-predictions with confidence lower
+ than this threshold are discarded for memory-replay
+
+ pytorch_cfg_updates: dict, optional, default = None.
+ A dictionary of updates to the pytorch config. The keys are the dot-separated
+ paths to the values to update in the config.
+ For example, to update the gpus to run the training on, you can use:
+ ```
+ pytorch_cfg_updates={"runner.gpus": [0,1,2,3]}
+ ```
+
+ Returns
+ -------
+ None
+
+ Examples
+ --------
+ To train the network for first shuffle of the training dataset
+
+ >>> deeplabcut.train_network('/analysis/project/reaching-task/config.yaml')
+
+ To train the network for second shuffle of the training dataset
+
+ >>> deeplabcut.train_network(
+ '/analysis/project/reaching-task/config.yaml',
+ shuffle=2,
+ keepdeconvweights=True,
+ )
+
+ To train the network for shuffle created with a PyTorch engine, while overriding the
+ number of epochs, batch size and other parameters.
+
+ >>> deeplabcut.train_network(
+ '/analysis/project/reaching-task/config.yaml',
+ shuffle=1,
+ batch_size=8,
+ epochs=100,
+ save_epochs=10,
+ display_iters=50,
+ )
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import train_network
+
+ if max_snapshots_to_keep is None:
+ max_snapshots_to_keep = 5
+
+ return train_network(
+ str(config),
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ max_snapshots_to_keep=max_snapshots_to_keep,
+ displayiters=displayiters,
+ saveiters=saveiters,
+ maxiters=maxiters,
+ allow_growth=allow_growth,
+ gputouse=gputouse,
+ autotune=autotune,
+ keepdeconvweights=keepdeconvweights,
+ superanimal_name=superanimal_name,
+ superanimal_transfer_learning=superanimal_transfer_learning,
+ modelprefix=modelprefix,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis import train_network
+
+ return train_network(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ modelprefix=modelprefix,
+ device=device,
+ snapshot_path=snapshot_path,
+ detector_path=detector_path,
+ load_head_weights=keepdeconvweights,
+ batch_size=batch_size,
+ epochs=epochs,
+ save_epochs=save_epochs,
+ detector_batch_size=detector_batch_size,
+ detector_epochs=detector_epochs,
+ detector_save_epochs=detector_save_epochs,
+ display_iters=displayiters,
+ max_snapshots_to_keep=max_snapshots_to_keep,
+ pose_threshold=pose_threshold,
+ pytorch_cfg_updates=pytorch_cfg_updates,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def return_train_network_path(
+ config,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ modelprefix: str = "",
+ engine: Engine | None = None,
+) -> tuple[Path, Path, Path]:
+ """
+ Returns the training and test pose config file names as well as the folder where the
+ snapshot is
+
+ Parameters
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ shuffle: int
+ Integer value specifying the shuffle index to select for training.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note
+ that TrainingFraction is a list in config.yaml).
+
+ modelprefix: str, optional
+ Directory containing the deeplabcut models to use when evaluating the network.
+ By default, the models are assumed to exist in the project folder.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Returns the triple: trainposeconfigfile, testposeconfigfile, snapshotfolder
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import return_train_network_path
+
+ return return_train_network_path(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ modelprefix=modelprefix,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ return_train_network_path,
+ )
+
+ return return_train_network_path(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ modelprefix=modelprefix,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def evaluate_network(
+ config: str | Path,
+ Shuffles: Iterable[int] = (1,),
+ trainingsetindex: int | str = 0,
+ plotting: bool | str = False,
+ show_errors: bool = True,
+ comparisonbodyparts: str | list[str] = "all",
+ gputouse: str | None = None,
+ rescale: bool = False,
+ modelprefix: str = "",
+ per_keypoint_evaluation: bool = False,
+ snapshots_to_evaluate: list[str] | None = None,
+ pcutoff: float | list[float] | dict[str, float] | None = None,
+ engine: Engine | None = None,
+ **torch_kwargs,
+):
+ """Evaluates the network.
+
+ Evaluates the network based on the saved models at different stages of the training
+ network. The evaluation results are stored in the .h5 and .csv file under the
+ subdirectory 'evaluation_results'. Change the snapshotindex parameter in the config
+ file to 'all' in order to evaluate all the saved models.
+
+ Parameters
+ ----------
+ config : string
+ Full path of the config.yaml file.
+
+ Shuffles: list, optional, default=[1]
+ List of integers specifying the shuffle indices of the training dataset.
+
+ trainingsetindex: int or str, optional, default=0
+ Integer specifying which "TrainingsetFraction" to use.
+ Note that "TrainingFraction" is a list in config.yaml. This variable can also
+ be set to "all".
+
+ plotting: bool or str, optional, default=False
+ Plots the predictions on the train and test images.
+ If provided it must be either ``True``, ``False``, ``"bodypart"``, or
+ ``"individual"``. Setting to ``True`` defaults as ``"bodypart"`` for
+ multi-animal projects.
+ If a detector is used, the predicted bounding boxes will also be plotted.
+
+ show_errors: bool, optional, default=True
+ Display train and test errors.
+
+ comparisonbodyparts: str or list, optional, default="all"
+ The average error will be computed for those body parts only.
+ The provided list has to be a subset of the defined body parts.
+
+ gputouse: int or None, optional, default=None
+ Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a
+ GPU put `None``.
+ See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ rescale: bool, optional, default=False
+ Evaluate the model at the ``'global_scale'`` variable (as set in the
+ ``pose_config.yaml`` file for a particular project). I.e. every image will be
+ resized according to that scale and prediction will be compared to the resized
+ ground truth. The error will be reported in pixels at rescaled to the
+ *original* size. I.e. For a [200,200] pixel image evaluated at
+ ``global_scale=.5``, the predictions are calculated on [100,100] pixel images,
+ compared to 1/2*ground truth and this error is then multiplied by 2!.
+ The evaluation images are also shown for the original size!
+
+ modelprefix: str, optional, default=""
+ Directory containing the deeplabcut models to use when evaluating the network.
+ By default, the models are assumed to exist in the project folder.
+
+ per_keypoint_evaluation: bool, default=False
+ Compute the train and test RMSE for each keypoint, and save the results to
+ a {model_name}-keypoint-results.csv in the evalution-results folder
+
+ snapshots_to_evaluate: List[str], optional, default=None
+ List of snapshot names to evaluate (e.g. ["snapshot-5000", "snapshot-7500"]).
+
+ pcutoff: float | list[float] | dict[str, float] | None, default=None
+ Only for the PyTorch engine. For the TensorFlow engine, please set the pcutoff
+ in the `config.yaml` file.
+ The cutoff to use for computing evaluation metrics. When `None` (default), the
+ cutoff will be loaded from the project config. If a list is provided, there
+ should be one value for each bodypart and one value for each unique bodypart
+ (if there are any). If a dict is provided, the keys should be bodyparts
+ mapping to pcutoff values for each bodypart. Bodyparts that are not defined
+ in the dict will have pcutoff set to 0.6.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ torch_kwargs:
+ You can add any keyword arguments for the deeplabcut.pose_estimation_pytorch
+ evaluate_network function here. These arguments are passed to the downstream
+ function. Available parameters are `snapshotindex`, which overrides the
+ `snapshotindex` parameter in the project configuration file. For top-down models
+ the `detector_snapshot_index` parameter can override the index of the detector
+ to use for evaluation in the project configuration file.
+
+ Returns
+ -------
+ None
+
+ Examples
+ --------
+ If you do not want to plot and evaluate with shuffle set to 1.
+
+ >>> deeplabcut.evaluate_network(
+ '/analysis/project/reaching-task/config.yaml', Shuffles=[1],
+ )
+
+ If you want to plot and evaluate with shuffle set to 0 and 1.
+
+ >>> deeplabcut.evaluate_network(
+ '/analysis/project/reaching-task/config.yaml',
+ Shuffles=[0, 1],
+ plotting=True,
+ )
+
+ If you want to plot assemblies for a maDLC project
+
+ >>> deeplabcut.evaluate_network(
+ '/analysis/project/reaching-task/config.yaml',
+ Shuffles=[1],
+ plotting="individual",
+ )
+
+ If you have a PyTorch model for which you want to set a different p-cutoff for
+ "left_ear" and "right_ear" bodyparts, and keep the one set in the project config
+ for other bodyparts:
+
+ >>> deeplabcut.evaluate_network(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> Shuffles=[0, 1],
+ >>> pcutoff={"left_ear": 0.8, "right_ear": 0.8},
+ >>> )
+
+ Note: This defaults to standard plotting for single-animal projects.
+ """
+ if engine is None:
+ cfg = _load_config(config)
+ engines = set()
+ for shuffle in Shuffles:
+ engines.add(
+ get_shuffle_engine(
+ cfg,
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+ )
+ if len(engines) == 0:
+ raise ValueError(
+ f"You must pass at least one shuffle to evaluate (had {list(Shuffles)})"
+ )
+ elif len(engines) > 1:
+ raise ValueError(
+ f"All shuffles must have the same engine (found {list(engines)})"
+ )
+ engine = engines.pop()
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import evaluate_network
+
+ return evaluate_network(
+ str(config),
+ Shuffles=Shuffles,
+ trainingsetindex=trainingsetindex,
+ plotting=plotting,
+ show_errors=show_errors,
+ comparisonbodyparts=comparisonbodyparts,
+ gputouse=gputouse,
+ rescale=rescale,
+ modelprefix=modelprefix,
+ per_keypoint_evaluation=per_keypoint_evaluation,
+ snapshots_to_evaluate=snapshots_to_evaluate,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis import evaluate_network
+
+ _update_device(gputouse, torch_kwargs)
+ return evaluate_network(
+ config,
+ shuffles=Shuffles,
+ trainingsetindex=trainingsetindex,
+ plotting=plotting,
+ show_errors=show_errors,
+ comparison_bodyparts=comparisonbodyparts,
+ per_keypoint_evaluation=per_keypoint_evaluation,
+ modelprefix=modelprefix,
+ pcutoff=pcutoff,
+ **torch_kwargs,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def return_evaluate_network_data(
+ config: str,
+ shuffle: int = 0,
+ trainingsetindex: int = 0,
+ comparisonbodyparts: str | list[str] = "all",
+ Snapindex: str | int | None = None,
+ rescale: bool = False,
+ fulldata: bool = False,
+ show_errors: bool = True,
+ modelprefix: str = "",
+ returnjustfns: bool = True,
+ engine: Engine | None = None,
+):
+ """
+ Returns the results for (previously evaluated) network. deeplabcut.evaluate_network(..)
+ Returns list of (per model): [trainingsiterations,trainfraction,shuffle,trainerror,testerror,pcutoff,trainerrorpcutoff,testerrorpcutoff,Snapshots[snapindex],scale,net_type]
+
+ This function is only implemented for tensorflow models/shuffles, and will throw
+ an error if called with a PyTorch shuffle.
+
+ If fulldata=True, also returns (the complete annotation and prediction array)
+ Returns list of: (DataMachine, Data, data, trainIndices, testIndices, trainFraction, DLCscorer,comparisonbodyparts, cfg, Snapshots[snapindex])
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ shuffle: integer
+ integers specifying shuffle index of the training dataset. The default is 0.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). This
+ variable can also be set to "all".
+
+ comparisonbodyparts: list of bodyparts, Default is "all".
+ The average error will be computed for those body parts only (Has to be a subset of the body parts).
+
+ rescale: bool, default False
+ Evaluate the model at the 'global_scale' variable (as set in the test/pose_config.yaml file for a particular project). I.e. every
+ image will be resized according to that scale and prediction will be compared to the resized ground truth. The error will be reported
+ in pixels at rescaled to the *original* size. I.e. For a [200,200] pixel image evaluated at global_scale=.5, the predictions are calculated
+ on [100,100] pixel images, compared to 1/2*ground truth and this error is then multiplied by 2!. The evaluation images are also shown for the
+ original size!
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Examples
+ --------
+ If you do not want to plot
+ >>> deeplabcut._evaluate_network_data('/analysis/project/reaching-task/config.yaml', shuffle=[1])
+ --------
+ If you want to plot
+ >>> deeplabcut.evaluate_network('/analysis/project/reaching-task/config.yaml',shuffle=[1],plotting=True)
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import return_evaluate_network_data
+
+ return return_evaluate_network_data(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ comparisonbodyparts=comparisonbodyparts,
+ Snapindex=Snapindex,
+ rescale=rescale,
+ fulldata=fulldata,
+ show_errors=show_errors,
+ modelprefix=modelprefix,
+ returnjustfns=returnjustfns,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def analyze_videos(
+ config: str,
+ videos: list[str],
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ gputouse: str | None = None,
+ save_as_csv: bool = False,
+ in_random_order: bool = True,
+ destfolder: str | None = None,
+ batchsize: int = None,
+ cropping: list[int] | None = None,
+ TFGPUinference: bool = True,
+ dynamic: tuple[bool, float, int] = (False, 0.5, 10),
+ modelprefix: str = "",
+ robust_nframes: bool = False,
+ allow_growth: bool = False,
+ use_shelve: bool = False,
+ auto_track: bool = True,
+ n_tracks: int | None = None,
+ animal_names: list[str] | None = None,
+ calibrate: bool = False,
+ identity_only: bool = False,
+ use_openvino: str | None = None,
+ engine: Engine | None = None,
+ **torch_kwargs,
+):
+ """Makes prediction based on a trained network.
+
+ The index of the trained network is specified by parameters in the config file
+ (in particular the variable 'snapshotindex').
+
+ The labels are stored as MultiIndex Pandas Array, which contains the name of
+ the network, body part name, (x, y) label position in pixels, and the
+ likelihood for each frame per body part. These arrays are stored in an
+ efficient Hierarchical Data Format (HDF) in the same directory where the video
+ is stored. However, if the flag save_as_csv is set to True, the data can also
+ be exported in comma-separated values format (.csv), which in turn can be
+ imported in many programs, such as MATLAB, R, Prism, etc.
+
+ Parameters
+ ----------
+ config: str
+ Full path of the config.yaml file.
+
+ videos: list[str]
+ A list of strings containing the full paths to videos for analysis or a path to
+ the directory, where all the videos with same extension are stored.
+
+ videotype: str, optional, default=""
+ Checks for the extension of the video in case the input to the video is a
+ directory. Only videos with this extension are analyzed. If left unspecified,
+ videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept.
+
+ shuffle: int, optional, default=1
+ An integer specifying the shuffle index of the training dataset used for
+ training the network.
+
+ trainingsetindex: int, optional, default=0
+ Integer specifying which TrainingsetFraction to use.
+ By default the first (note that TrainingFraction is a list in config.yaml).
+
+ gputouse: int or None, optional, default=None
+ Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``:
+ you can use ``device``).
+ Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a
+ GPU put ``None``.
+ See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ save_as_csv: bool, optional, default=False
+ Saves the predictions in a .csv file.
+
+ in_random_order: bool, optional (default=True)
+ Whether or not to analyze videos in a random order.
+ This is only relevant when specifying a video directory in `videos`.
+
+ destfolder: string or None, optional, default=None
+ Specifies the destination folder for analysis data. If ``None``, the path of
+ the video is used. Note that for subsequent analysis this folder also needs to
+ be passed.
+
+ batchsize: int or None, optional, default=None
+ Currently not supported by the PyTorch engine.
+ Change batch size for inference; if given overwrites value in ``pose_cfg.yaml``.
+
+ cropping: list or None, optional, default=None
+ Currently not supported by the PyTorch engine.
+ List of cropping coordinates as [x1, x2, y1, y2].
+ Note that the same cropping parameters will then be used for all videos.
+ If different video crops are desired, run ``analyze_videos`` on individual
+ videos with the corresponding cropping coordinates.
+
+ TFGPUinference: bool, optional, default=True
+ Only for the TensorFlow engine.
+ Perform inference on GPU with TensorFlow code. Introduced in "Pretraining
+ boosts out-of-domain robustness for pose estimation" by Alexander Mathis,
+ Mert Yüksekgönül, Byron Rogers, Matthias Bethge, Mackenzie W. Mathis.
+ Source: https://arxiv.org/abs/1909.11229
+
+ dynamic: tuple(bool, float, int) triple containing (state, det_threshold, margin)
+ If the state is true, then dynamic cropping will be performed. That means that
+ if an object is detected (i.e. any body part > detectiontreshold), then object
+ boundaries are computed according to the smallest/largest x position and
+ smallest/largest y position of all body parts. This window is expanded by the
+ margin and from then on only the posture within this crop is analyzed (until the
+ object is lost, i.e. >> deeplabcut.analyze_videos(
+ 'C:\\myproject\\reaching-task\\config.yaml',
+ ['C:\\yourusername\\rig-95\\Videos\\reachingvideo1.avi'],
+ )
+
+ Analyzing a single video on Linux/MacOS
+
+ >>> deeplabcut.analyze_videos(
+ '/analysis/project/reaching-task/config.yaml',
+ ['/analysis/project/videos/reachingvideo1.avi'],
+ )
+
+ Analyze all videos of type ``avi`` in a folder
+
+ >>> deeplabcut.analyze_videos(
+ '/analysis/project/reaching-task/config.yaml',
+ ['/analysis/project/videos'],
+ videotype='.avi',
+ )
+
+ Analyze multiple videos
+
+ >>> deeplabcut.analyze_videos(
+ '/analysis/project/reaching-task/config.yaml',
+ [
+ '/analysis/project/videos/reachingvideo1.avi',
+ '/analysis/project/videos/reachingvideo2.avi',
+ ],
+ )
+
+ Analyze multiple videos with ``shuffle=2``
+
+ >>> deeplabcut.analyze_videos(
+ '/analysis/project/reaching-task/config.yaml',
+ [
+ '/analysis/project/videos/reachingvideo1.avi',
+ '/analysis/project/videos/reachingvideo2.avi',
+ ],
+ shuffle=2,
+ )
+
+ Analyze multiple videos with ``shuffle=2``, save results as an additional csv file
+
+ >>> deeplabcut.analyze_videos(
+ '/analysis/project/reaching-task/config.yaml',
+ [
+ '/analysis/project/videos/reachingvideo1.avi',
+ '/analysis/project/videos/reachingvideo2.avi',
+ ],
+ shuffle=2,
+ save_as_csv=True,
+ )
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import analyze_videos
+
+ kwargs = {}
+ if use_openvino is not None: # otherwise default comes from tensorflow API
+ kwargs["use_openvino"] = use_openvino
+
+ return analyze_videos(
+ config,
+ videos,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ gputouse=gputouse,
+ save_as_csv=save_as_csv,
+ in_random_order=in_random_order,
+ destfolder=destfolder,
+ batchsize=batchsize,
+ cropping=cropping,
+ TFGPUinference=TFGPUinference,
+ dynamic=dynamic,
+ modelprefix=modelprefix,
+ robust_nframes=robust_nframes,
+ allow_growth=allow_growth,
+ use_shelve=use_shelve,
+ auto_track=auto_track,
+ n_tracks=n_tracks,
+ animal_names=animal_names,
+ calibrate=calibrate,
+ identity_only=identity_only,
+ **kwargs,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis import analyze_videos
+
+ _update_device(gputouse, torch_kwargs)
+
+ if batchsize is not None:
+ if "batch_size" in torch_kwargs:
+ print(
+ f"You called analyze_videos with parameters ``batchsize={batchsize}"
+ f"`` and batch_size={torch_kwargs['batch_size']}. Only one is "
+ f"needed/used. Using batch size {torch_kwargs['batch_size']}"
+ )
+ else:
+ torch_kwargs["batch_size"] = batchsize
+
+ return analyze_videos(
+ config,
+ videos=videos,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ save_as_csv=save_as_csv,
+ in_random_order=in_random_order,
+ destfolder=destfolder,
+ dynamic=dynamic,
+ modelprefix=modelprefix,
+ use_shelve=use_shelve,
+ robust_nframes=robust_nframes,
+ auto_track=auto_track,
+ n_tracks=n_tracks,
+ animal_names=animal_names,
+ calibrate=calibrate,
+ identity_only=identity_only,
+ overwrite=False,
+ cropping=cropping,
+ **torch_kwargs,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def create_tracking_dataset(
+ config: str,
+ videos: list[str],
+ track_method: str,
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ gputouse: int | None = None,
+ destfolder: str | None = None,
+ batchsize: int | None = None,
+ cropping: list[int] | None = None,
+ TFGPUinference: bool = True,
+ modelprefix: str = "",
+ robust_nframes: bool = False,
+ n_triplets: int = 1000,
+ engine: Engine | None = None,
+) -> str:
+ """Creates a tracking dataset to train a ReID tracklet stitcher.
+
+ Parameters
+ ----------
+ config: str
+ Full path of the config.yaml file.
+
+ videos: list[str]
+ A list of strings containing the full paths to videos from which to create a
+ tracking dataset, or a path to the directory where all the videos with same
+ extension are stored.
+
+ track_method: str
+ Specifies the tracker used to generate the pose estimation data. Must be either
+ 'box', 'skeleton', or 'ellipse'.
+
+ videotype: str, optional, default=""
+ Checks for the extension of the video in case the input to the video is a
+ directory. Only videos with this extension are analyzed. If left unspecified,
+ videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept.
+
+ shuffle: int, optional, default=1
+ An integer specifying the shuffle index of the training dataset used for
+ training the network.
+
+ trainingsetindex: int, optional, default=0
+ Integer specifying which TrainingsetFraction to use.
+ By default the first (note that TrainingFraction is a list in config.yaml).
+
+ gputouse: int or None, optional, default=None
+ Only for the TensorFlow engine (for the PyTorch engine use ``device``).
+ Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a
+ GPU put ``None``. See:
+ https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ TFGPUinference: bool, optional, default=True
+ Only for the TensorFlow engine.
+ Perform inference on GPU with TensorFlow code. Introduced in "Pretraining
+ boosts out-of-domain robustness for pose estimation" by Alexander Mathis,
+ Mert Yüksekgönül, Byron Rogers, Matthias Bethge, Mackenzie W. Mathis.
+ Source: https://arxiv.org/abs/1909.11229
+
+ destfolder:
+ Specifies the destination folder for analysis data. If ``None``, the path of
+ the video is used. Note that for subsequent analysis this folder also needs to
+ be passed.
+
+ modelprefix: str, optional, default=""
+ Directory containing the deeplabcut models to use when evaluating the network.
+ By default, the models are assumed to exist in the project folder.
+
+ robust_nframes: bool, optional, default=False
+ Evaluate a video's number of frames in a robust manner.
+ This option is slower (as the whole video is read frame-by-frame),
+ but does not rely on metadata, hence its robustness against file corruption.
+
+ n_triplets: int, default=1000
+ The number of triplets to extract for the dataset.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Returns
+ -------
+ DLCScorer: str
+ the scorer used to analyze the videos
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import create_tracking_dataset
+
+ return create_tracking_dataset(
+ config,
+ videos,
+ track_method,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ gputouse=gputouse,
+ destfolder=destfolder,
+ batchsize=batchsize,
+ cropping=cropping,
+ TFGPUinference=TFGPUinference,
+ modelprefix=modelprefix,
+ robust_nframes=robust_nframes,
+ n_triplets=n_triplets,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis import create_tracking_dataset
+ return create_tracking_dataset(
+ config,
+ videos,
+ track_method,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ destfolder=destfolder,
+ batch_size=batchsize,
+ cropping=cropping,
+ modelprefix=modelprefix,
+ robust_nframes=robust_nframes,
+ n_triplets=n_triplets,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def analyze_images(
+ config: str | Path,
+ images: str | Path | list[str] | list[Path],
+ frame_type: str | None = None,
+ destfolder: str | Path | None = None,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ max_individuals: int | None = None,
+ device: str | None = None,
+ snapshot_index: int | None = None,
+ detector_snapshot_index: int | None = None,
+ save_as_csv: bool = False,
+ modelprefix: str = "",
+ plotting: bool | str = False,
+ pcutoff: float | None = None,
+ bbox_pcutoff: float | None = None,
+ plot_skeleton: bool = False,
+) -> dict[str, dict[str, np.ndarray | np.ndarray]]:
+ """Analyzes images with a DeepLabCut model and stores the output in an H5 file.
+
+ This method is only implemented for PyTorch models.
+
+ The labels are stored as Pandas DataFrame, which contains the name of the network,
+ body part name, (x, y) label position in pixels, and the likelihood for each frame
+ per body part.
+
+ Parameters
+ ----------
+ config : str, Path
+ Full path of the project's config.yaml file.
+
+ images: str, Path, list[str], list[Path]
+ The image(s) to run inference on. Can be the path to an image, the path
+ to a directory containing images, or a list of image paths or directories
+ containing images.
+
+ frame_type: string, optional
+ Filters the images to analyze to only the ones with the given suffix (e.g.
+ setting `frame_type`=".png" will only analyze ".png" images). The default
+ behavior analyzes all ".jpg", ".jpeg" and ".png" images.
+
+ destfolder: str, Path, optional
+ The directory where the predictions will be stored. If None, the predictions
+ will be stored in the same directory as the first image given in the `images`
+ argument (if it's a directory, that directory will be used; if it's an image,
+ the directory containing the image will be used).
+
+ shuffle: int, optional
+ An integer specifying the shuffle with which to run image analysis.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default, the first one
+ is used (note that TrainingFraction is a list in config.yaml).
+
+ max_individuals: int, optional
+ The maximum number of individuals to detect in each image. Set to the number of
+ individuals in the project if None.
+
+ device: str, optional
+ The CUDA device to use for training. If None, the device will be taken from the
+ ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. For
+ more information, see https://pytorch.org/docs/stable/notes/cuda.html
+
+ snapshot_index: int, optional
+ Index (starting at 0) of the snapshot to use for image analysis. To evaluate the
+ last one, use -1. Default uses the value set in the project config.
+
+ detector_snapshot_index: int, optional
+ Only for Top-Down PyTorch models. If defined, uses the detector with the given
+ index for pose estimation. To evaluate the last one, use -1. Default uses the
+ value set in the project config.
+
+ save_as_csv: bool, optional
+ Saves the predictions in a .csv file. The default is ``False``; if provided it
+ must be either ``True`` or ``False``.
+
+ modelprefix: str, optional
+ Directory containing the deeplabcut models to use when running image analysis.
+ By default, the models are assumed to exist in the project folder.
+
+ plotting: bool, str, default=False
+ Plots the predictions made by the model on the analyzed images. Results will be
+ stored in a folder named `LabeledImages_{scorer}`, where scorer is the name
+ of the model used to analyze the images. This folder will be in the same
+ directory as the file containing the predictions (either the given `destfolder`,
+ or the folder containing the first image to analyze).
+
+ If provided it must be either ``True``, ``False``, ``"bodypart"``, or
+ ``"individual"``. Setting to ``True`` defaults as ``"bodypart"`` for
+ multi-animal projects. If a detector is used, the predicted bounding boxes
+ will also be plotted.
+
+ pcutoff: float, optional, default=None
+ The cutoff score when plotting pose predictions. Must be None or in
+ (0, 1). If None, the pcutoff is read from the project configuration file.
+
+ bbox_pcutoff: float, optional, default=None
+ The cutoff score when plotting bounding box predictions. Must be
+ None or in (0, 1). If None, it is read from the project configuration file.
+
+ plot_skeleton: bool, default=False
+ If a skeleton is defined in the project's config.yaml, whether
+ to plot the skeleton connecting the predicted bodyparts on the images.
+
+ Returns
+ -------
+ A dictionary mapping image paths (as strings) to model predictions.
+
+ Examples
+ --------
+ If you want to analyze all frames in /analysis/project/my_images
+ >>> import deeplabcut
+ >>> deeplabcut.analyze_images(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> "/analysis/project/my_images",
+ >>> )
+ >>>
+
+ If you want to analyze two specific images with your shuffle 3 model:
+ >>> import deeplabcut
+ >>> deeplabcut.analyze_images(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> images=["image_001.png", "img_002.jpg"],
+ >>> shuffle=3,
+ >>> )
+ >>>
+
+ If you want to analyze frames in a folder, save them and plot predictions:
+ >>> import deeplabcut
+ >>> deeplabcut.analyze_images(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> "/analysis/project/my_images",
+ >>> shuffle=3,
+ >>> destfolder="/analysis/project/my_images_analyzed",
+ >>> plotting=True,
+ >>> )
+ >>>
+ --------
+ """
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch import analyze_images
+
+ return analyze_images(
+ config=config,
+ images=images,
+ frame_type=frame_type,
+ output_dir=destfolder,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ snapshot_index=snapshot_index,
+ detector_snapshot_index=detector_snapshot_index,
+ modelprefix=modelprefix,
+ device=device,
+ save_as_csv=save_as_csv,
+ max_individuals=max_individuals,
+ plotting=plotting,
+ pcutoff=pcutoff,
+ bbox_pcutoff=bbox_pcutoff,
+ plot_skeleton=plot_skeleton,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def analyze_time_lapse_frames(
+ config: str,
+ directory: str,
+ frametype: str = ".png",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ gputouse: int | None = None,
+ device: str | None = None,
+ save_as_csv: bool = False,
+ modelprefix: str = "",
+ engine: Engine | None = None,
+):
+ """
+ Analyzed all images (of type = frametype) in a folder and stores the output in one file.
+
+ You can crop the frames (before analysis), by changing 'cropping'=True and setting
+ 'x1','x2','y1','y2' in the config file.
+
+ Output: The labels are stored as MultiIndex Pandas Array, which contains the name
+ of the network, body part name, (x, y) label position in pixels, and the likelihood
+ for each frame per body part. These arrays are stored in an efficient Hierarchical
+ Data Format (HDF) in the same directory, where the video is stored. However, if the
+ flag save_as_csv is set to True, the data can also be exported in comma-separated
+ values format (.csv), which in turn can be imported in many programs, such as
+ MATLAB, R, Prism, etc.
+
+ Parameters
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ directory: string
+ Full path to directory containing the frames that shall be analyzed
+
+ frametype: string, optional
+ Checks for the file extension of the frames. Only images with this extension are
+ analyzed. The default is ``.png``
+
+ shuffle: int, optional
+ An integer specifying the shuffle index of the training dataset used for
+ training the network. The default is 1.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note
+ that TrainingFraction is a list in config.yaml).
+
+ gputouse: int, optional.
+ Only for TensorFlow models. For PyTorch models, please use `device`. Natural
+ number indicating the number of your GPU (see number in nvidia-smi). If you do
+ not have a GPU put None. See:
+ https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ device: str, optional
+ The CUDA device to use for training. If None, the device will be taken from the
+ ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. For
+ more information, see https://pytorch.org/docs/stable/notes/cuda.html
+
+ save_as_csv: bool, optional
+ Saves the predictions in a .csv file. The default is ``False``; if provided if
+ must be either ``True`` or ``False``
+
+ Examples
+ --------
+ If you want to analyze all frames in /analysis/project/timelapseexperiment1
+ >>> import deeplabcut
+ >>> deeplabcut.analyze_time_lapse_frames(
+ >>> '/analysis/project/reaching-task/config.yaml',
+ >>> '/analysis/project/timelapseexperiment1'
+ >>> )
+
+ --------
+
+ Note: for test purposes one can extract all frames from a video with ffmeg, e.g.
+ >>> ffmpeg -i testvideo.avi "thumb%04d.png"
+
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import analyze_time_lapse_frames
+
+ return analyze_time_lapse_frames(
+ config,
+ directory,
+ frametype=frametype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ gputouse=gputouse,
+ save_as_csv=save_as_csv,
+ modelprefix=modelprefix,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch import analyze_images
+
+ return analyze_images(
+ config=config,
+ images=directory,
+ output_dir=directory,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ device=_gpu_to_use_to_device(gputouse, device),
+ save_as_csv=save_as_csv,
+ modelprefix=modelprefix,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def convert_detections2tracklets(
+ config: str,
+ videos: list[str],
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ overwrite: bool = False,
+ destfolder: str | None = None,
+ ignore_bodyparts: list[str] | None = None,
+ inferencecfg: dict | None = None,
+ modelprefix: str = "",
+ greedy: bool = False,
+ calibrate: bool = False,
+ window_size: int = 0,
+ identity_only: int = False,
+ track_method: str = "",
+ engine: Engine | None = None,
+):
+ """
+ This should be called at the end of deeplabcut.analyze_videos for multianimal projects!
+
+ Parameters
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ videos : list
+ A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.
+
+ videotype: string, optional
+ Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed.
+ If left unspecified, videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept.
+
+ shuffle: int, optional
+ An integer specifying the shuffle index of the training dataset used for training the network. The default is 1.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml).
+
+ overwrite: bool, optional.
+ Overwrite tracks file i.e. recompute tracks from full detections and overwrite.
+
+ destfolder: string, optional
+ Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
+ folder also needs to be passed.
+
+ ignore_bodyparts: optional
+ List of body part names that should be ignored during tracking (advanced).
+ By default, all the body parts are used.
+
+ inferencecfg: Default is None.
+ Configuration file for inference (assembly of individuals). Ideally
+ should be obtained from cross validation (during evaluation). By default
+ the parameters are loaded from inference_cfg.yaml, but these get_level_values
+ can be overwritten.
+
+ calibrate: bool, optional (default=False)
+ If True, use training data to calibrate the animal assembly procedure.
+ This improves its robustness to wrong body part links,
+ but requires very little missing data.
+
+ window_size: int, optional (default=0)
+ Recurrent connections in the past `window_size` frames are
+ prioritized during assembly. By default, no temporal coherence cost
+ is added, and assembly is driven mainly by part affinity costs.
+
+ identity_only: bool, optional (default=False)
+ If True and animal identity was learned by the model,
+ assembly and tracking rely exclusively on identity prediction.
+
+ track_method: string, optional
+ Specifies the tracker used to generate the pose estimation data.
+ For multiple animals, must be either 'box', 'skeleton', or 'ellipse'
+ and will be taken from the config.yaml file if none is given.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Examples
+ --------
+ If you want to convert detections to tracklets:
+ >>> import deeplabcut
+ >>> deeplabcut.convert_detections2tracklets(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> ["/analysis/project/video1.mp4"],
+ >>> videotype='.mp4',
+ >>> )
+
+ If you want to convert detections to tracklets based on box_tracker:
+ >>> import deeplabcut
+ >>> deeplabcut.convert_detections2tracklets(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> ["/analysis/project/video1.mp4"],
+ >>> videotype=".mp4",
+ >>> track_method="box",
+ >>> )
+
+ --------
+
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import convert_detections2tracklets
+
+ return convert_detections2tracklets(
+ config,
+ videos,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ overwrite=overwrite,
+ destfolder=destfolder,
+ ignore_bodyparts=ignore_bodyparts,
+ inferencecfg=inferencecfg,
+ modelprefix=modelprefix,
+ greedy=greedy,
+ calibrate=calibrate,
+ window_size=window_size,
+ identity_only=identity_only,
+ track_method=track_method,
+ )
+
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis import convert_detections2tracklets
+
+ if greedy or calibrate or window_size:
+ raise NotImplementedError(
+ f"The 'greedy', 'calibrate' and 'window_size' option are not yet "
+ f"implemented with {engine}"
+ )
+
+ return convert_detections2tracklets(
+ config,
+ videos,
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ overwrite=overwrite,
+ destfolder=destfolder,
+ ignore_bodyparts=ignore_bodyparts,
+ inferencecfg=inferencecfg,
+ modelprefix=modelprefix,
+ identity_only=identity_only,
+ track_method=track_method,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def extract_maps(
+ config,
+ shuffle: int = 0,
+ trainingsetindex: int = 0,
+ gputouse: int | None = None,
+ device: str | None = None,
+ rescale: bool = False,
+ Indices: list[int] | None = None,
+ modelprefix: str = "",
+ engine: Engine | None = None,
+):
+ """
+ Extracts the scoremap, locref, partaffinityfields (if available).
+
+ Returns a dictionary indexed by: trainingsetfraction, snapshotindex, and imageindex
+ for those keys, each item contains: (image, scmap, locref, paf, bpt_names,
+ partaffinity_graph, imagename, True/False if this image was in trainingset).
+
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ shuffle: integer
+ integers specifying shuffle index of the training dataset. The default is 0.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note
+ that TrainingFraction is a list in config.yaml). This variable can also be set
+ to "all".
+
+ gputouse: int or None, optional, default=None
+ For the TensorFlow engine (for the PyTorch engine see ``device``). Specifies
+ the GPU to use (see number in ``nvidia-smi``). If you do not have a GPU put
+ ``None``. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ device: str or None, optional, default=None
+ The CUDA device to use for training. If None, the device will be taken from the
+ ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. See
+ https://pytorch.org/docs/stable/notes/cuda.html for more information.
+
+ rescale: bool, default False
+ Evaluate the model at the 'global_scale' variable (as set in the test/pose_config.yaml file for a particular project). I.e. every
+ image will be resized according to that scale and prediction will be compared to the resized ground truth. The error will be reported
+ in pixels at rescaled to the *original* size. I.e. For a [200,200] pixel image evaluated at global_scale=.5, the predictions are calculated
+ on [100,100] pixel images, compared to 1/2*ground truth and this error is then multiplied by 2!. The evaluation images are also shown for the
+ original size!
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Examples
+ --------
+ If you want to extract the data for image 0 and 103 (of the training set) for model trained with shuffle 0.
+ >>> deeplabcut.extract_maps(configfile,0,Indices=[0,103])
+
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import extract_maps
+
+ return extract_maps(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ gputouse=gputouse,
+ rescale=rescale,
+ Indices=Indices,
+ modelprefix=modelprefix,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch import extract_maps
+
+ return extract_maps(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ device=_gpu_to_use_to_device(gputouse, device),
+ rescale=rescale,
+ indices=Indices,
+ modelprefix=modelprefix,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def visualize_scoremaps(image: np.ndarray, scmap: np.ndarray):
+ """Plots scoremaps as an image overlay.
+
+ Args:
+ image: An image as a numpy array of shape (h, w, channels)
+ scmap: A scoremap of shape (h, w)
+
+ Returns:
+ The figure and axis on which the image scoremap was plot.
+ """
+ return visualization.visualize_scoremaps(image, scmap)
+
+
+def visualize_locrefs(
+ image: np.ndarray,
+ scmap: np.ndarray,
+ locref_x: np.ndarray,
+ locref_y: np.ndarray,
+ step: int = 5,
+ zoom_width: int = 0,
+):
+ """Plots a scoremap and the corresponding location refinement field on an image.
+
+ Args:
+ image: An image as a numpy array of shape (h, w, channels)
+ scmap: A scoremap of shape (h, w)
+ locref_x: The x-coordinate of the location refinement field, of shape (h, w)
+ locref_y: The y-coordinate of the location refinement field, of shape (h, w)
+ step: The step with which to plot the location refinement field.
+ zoom_width: The zoom width with which to plot the scoremaps.
+
+ Returns:
+ The figure and axis on which the image scoremap and locref field were plot.
+ """
+ return visualization.visualize_locrefs(
+ image, scmap, locref_x, locref_y, step=step, zoom_width=zoom_width
+ )
+
+
+def visualize_paf(
+ image: np.ndarray,
+ paf: np.ndarray,
+ step: int = 5,
+ colors: list | None = None,
+):
+ """Plots the PAF on top of the image.
+
+ Args:
+ image: Shape (height, width, channels). The image on which the model was run.
+ paf: Shape (height, width, 2 * len(paf_graph)). The PAF output by the model.
+ step: The step with which to plot the scoremaps.
+ colors: The colormap to use.
+
+ Returns:
+ The figure and axis on which the image PAF was plot.
+ """
+ return visualization.visualize_paf(image, paf, step=step, colors=colors)
+
+
+def extract_save_all_maps(
+ config,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ comparisonbodyparts: str | list[str] = "all",
+ extract_paf: bool = True,
+ all_paf_in_one: bool = True,
+ gputouse: int | None = None,
+ device: str | None = None,
+ rescale: bool = False,
+ Indices: list[int] | None = None,
+ modelprefix: str = "",
+ dest_folder: str = None,
+ snapshot_index: int | str | None = None,
+ detector_snapshot_index: int | str | None = None,
+ engine: Engine | None = None,
+):
+ """
+ Extracts the scoremap, location refinement field and part affinity field prediction of the model. The maps
+ will be rescaled to the size of the input image and stored in the corresponding model folder in /evaluation-results.
+
+ ----------
+ config : string
+ Full path of the config.yaml file as a string.
+
+ shuffle: integer
+ integers specifying shuffle index of the training dataset. The default is 1.
+
+ trainingsetindex: int, optional
+ Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). This
+ variable can also be set to "all".
+
+ comparisonbodyparts: list of bodyparts, Default is "all".
+ The average error will be computed for those body parts only (Has to be a subset of the body parts).
+
+ extract_paf : bool
+ Extract part affinity fields by default.
+ Note that turning it off will make the function much faster.
+
+ all_paf_in_one : bool
+ By default, all part affinity fields are displayed on a single frame.
+ If false, individual fields are shown on separate frames.
+
+ gputouse: int or None, optional, default=None
+ For the TensorFlow engine (for the PyTorch engine see ``device``). Specifies
+ the GPU to use (see number in ``nvidia-smi``). If you do not have a GPU put
+ ``None``. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+
+ device: str or None, optional, default=None
+ The CUDA device to use for training. If None, the device will be taken from the
+ ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. See
+ https://pytorch.org/docs/stable/notes/cuda.html for more information.
+
+ Indices: default None
+ For which images shall the scmap/locref and paf be computed? Give a list of images
+
+ nplots_per_row: int, optional (default=None)
+ Number of plots per row in grid plots. By default, calculated to approximate a squared grid of plots
+
+ snapshot_index: Only for PyTorch models. Index (starting at 0) of the snapshot we
+ want to extract maps with. To evaluate the last one, use -1. To extract maps
+ for all snapshots, use "all". Default uses the value set in the project config.
+
+ detector_snapshot_index: Only for TD PyTorch models. If defined, uses the detector
+ with the given index for pose estimation. To extract maps for all detector
+ snapshots, use "all". Default uses the value set in the project config.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Examples
+ --------
+ Calculated maps for images 0, 1 and 33.
+ >>> deeplabcut.extract_save_all_maps('/analysis/project/reaching-task/config.yaml', shuffle=1,Indices=[0,1,33])
+
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(config),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import extract_save_all_maps
+
+ return extract_save_all_maps(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ comparisonbodyparts=comparisonbodyparts,
+ extract_paf=extract_paf,
+ all_paf_in_one=all_paf_in_one,
+ gputouse=gputouse,
+ rescale=rescale,
+ Indices=Indices,
+ modelprefix=modelprefix,
+ dest_folder=dest_folder,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch import extract_save_all_maps
+
+ return extract_save_all_maps(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ comparison_bodyparts=comparisonbodyparts,
+ extract_paf=extract_paf,
+ all_paf_in_one=all_paf_in_one,
+ device=_gpu_to_use_to_device(gputouse, device),
+ rescale=rescale,
+ indices=Indices,
+ modelprefix=modelprefix,
+ snapshot_index=snapshot_index,
+ detector_snapshot_index=detector_snapshot_index,
+ dest_folder=dest_folder,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def export_model(
+ cfg_path: str,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ snapshotindex: int | None = None,
+ iteration: int = None,
+ TFGPUinference: bool = True,
+ overwrite: bool = False,
+ make_tar: bool = True,
+ wipepaths: bool = False,
+ without_detector: bool = False,
+ modelprefix: str = "",
+ engine: Engine | None = None,
+) -> None:
+ """Export DeepLabCut models for the model zoo or for live inference.
+
+ Saves the pose configuration, snapshot files, and frozen TF graph of the model to
+ directory named exported-models within the project directory (and an
+ `exported-models-pytorch` directory for PyTorch models).
+
+ Parameters
+ -----------
+
+ cfg_path : string
+ path to the DLC Project config.yaml file
+
+ shuffle : int, optional
+ the shuffle of the model to export. default = 1
+
+ trainingsetindex : int, optional
+ the index of the training fraction for the model you wish to export. default = 1
+
+ snapshotindex : int, optional
+ the snapshot index for the weights you wish to export. If None,
+ uses the snapshotindex as defined in 'config.yaml'. Default = None
+
+ iteration : int, optional
+ The model iteration (active learning loop) you wish to export. If None,
+ the iteration listed in the config file is used.
+
+ TFGPUinference : bool, optional
+ use the tensorflow inference model? Default = True
+ For inference using DeepLabCut-live, it is recommended to set TFGPIinference=False
+
+ overwrite : bool, optional
+ if the model you wish to export has already been exported, whether to overwrite. default = False
+
+ make_tar : bool, optional
+ Do you want to compress the exported directory to a tar file? Default = True
+ This is necessary to export to the model zoo, but not for live inference.
+
+ wipepaths : bool, optional
+ Removes the actual path of your project and the init_weights from pose_cfg.
+
+ without_detector: bool, optional
+ PyTorch engine only. Exports top-down models without the detector.
+
+ engine: Engine, optional, default = None.
+ The default behavior loads the engine for the shuffle from the metadata. You can
+ overwrite this by passing the engine as an argument, but this should generally
+ not be done.
+
+ Example:
+ --------
+ Export the first stored snapshot for model trained with shuffle 3:
+ >>> deeplabcut.export_model('/analysis/project/reaching-task/config.yaml',shuffle=3, snapshotindex=-1)
+ --------
+ """
+ if engine is None:
+ engine = get_shuffle_engine(
+ _load_config(cfg_path),
+ trainingsetindex=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.TF:
+ from deeplabcut.pose_estimation_tensorflow import export_model
+
+ return export_model(
+ cfg_path=cfg_path,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ snapshotindex=snapshotindex,
+ iteration=iteration,
+ TFGPUinference=TFGPUinference,
+ overwrite=overwrite,
+ make_tar=make_tar,
+ wipepaths=wipepaths,
+ modelprefix=modelprefix,
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis.export import export_model
+
+ return export_model(
+ config=cfg_path,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ snapshotindex=snapshotindex,
+ iteration=iteration,
+ overwrite=overwrite,
+ wipe_paths=wipepaths,
+ without_detector=without_detector,
+ modelprefix=modelprefix,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def _update_device(gpu_to_use: int | None, torch_kwargs: dict) -> None:
+ if "device" not in torch_kwargs and gpu_to_use is not None:
+ device = _gpu_to_use_to_device(gpu_to_use, device=None)
+ if device is not None:
+ torch_kwargs["device"] = device
+
+
+def _gpu_to_use_to_device(gpu_to_use: int | None, device: str | None) -> str | None:
+ if device is None and gpu_to_use is not None:
+ if isinstance(gpu_to_use, int):
+ device = f"cuda:{gpu_to_use}"
+ else:
+ device = gpu_to_use
+
+ return device
+
+
+def _load_config(config: str) -> dict:
+ config_path = Path(config)
+ if not config_path.exists():
+ raise FileNotFoundError(
+ f"Config {config} is not found. Please make sure that the file exists."
+ )
+
+ with open(config, "r") as f:
+ project_config = YAML(typ="safe", pure=True).load(f)
+
+ return project_config
diff --git a/deeplabcut/core/__init__.py b/deeplabcut/core/__init__.py
new file mode 100644
index 0000000000..117d127147
--- /dev/null
+++ b/deeplabcut/core/__init__.py
@@ -0,0 +1,10 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
diff --git a/deeplabcut/core/config.py b/deeplabcut/core/config.py
new file mode 100644
index 0000000000..1a638e48da
--- /dev/null
+++ b/deeplabcut/core/config.py
@@ -0,0 +1,74 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Simple helper methods related to configuration files stored in yaml files"""
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Callable
+
+from ruamel.yaml import YAML
+
+
+def read_config_as_dict(config_path: str | Path) -> dict:
+ """
+ Args:
+ config_path: the path to the configuration file to load
+
+ Returns:
+ The configuration file with pure Python classes
+ """
+ with open(config_path, "r") as f:
+ cfg = YAML(typ="safe", pure=True).load(f)
+
+ return cfg
+
+
+def write_config(config_path: str | Path, config: dict, overwrite: bool = True) -> None:
+ """Writes a pose configuration file to disk
+
+ Args:
+ config_path: the path where the config should be saved
+ config: the config to save
+ overwrite: whether to overwrite the file if it already exists
+
+ Raises:
+ FileExistsError if overwrite=True and the file already exists
+ """
+ if not overwrite and Path(config_path).exists():
+ raise FileExistsError(
+ f"Cannot write to {config_path} - set overwrite=True to force"
+ )
+
+ with open(config_path, "w") as file:
+ YAML().dump(config, file)
+
+
+def pretty_print(
+ config: dict,
+ indent: int = 0,
+ print_fn: Callable[[str], None] | None = None,
+) -> None:
+ """Prints a model configuration in a pretty and readable way
+
+ Args:
+ config: the config to print
+ indent: the base indent on all keys
+ print_fn: custom function to call (simply calls ``print`` if None)
+ """
+ if print_fn is None:
+ print_fn = print
+
+ for k, v in config.items():
+ if isinstance(v, dict):
+ print_fn(f"{indent * ' '}{k}:")
+ pretty_print(v, indent + 2, print_fn=print_fn)
+ else:
+ print_fn(f"{indent * ' '}{k}: {v}")
diff --git a/deeplabcut/core/conversion_table.py b/deeplabcut/core/conversion_table.py
new file mode 100644
index 0000000000..e5d9679fa9
--- /dev/null
+++ b/deeplabcut/core/conversion_table.py
@@ -0,0 +1,79 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Defines conversion tables mapping DeepLabCut project bodyparts to SA bodyparts"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+
+@dataclass
+class ConversionTable:
+ """Maps DLC project bodyparts to the corresponding SuperAnimal bodyparts
+
+ The conversion table must satisfy the following conditions (checked by validate):
+ - All SuperAnimal bodyparts must be valid (defined for the SuperAnimal model)
+ - All project bodyparts must be valid (defined for the DLC project)
+ """
+
+ super_animal: str
+ project_bodyparts: list[str]
+ super_animal_bodyparts: list[str]
+ table: dict[str, str]
+
+ def __post_init__(self):
+ """Validates the table"""
+ self.validate()
+
+ def to_array(self) -> np.ndarray:
+ """
+ Returns:
+ An array mapping the indices of SuperAnimal bodyparts
+
+ Raises:
+ ValueError: If the conversion table is misconfigured.
+ """
+ self.validate()
+ sa_indices = {sa_bpt: i for i, sa_bpt in enumerate(self.super_animal_bodyparts)}
+ sa_bpt_ordering = [self.table[bpt] for bpt in self.converted_bodyparts()]
+ return np.array([sa_indices[sa_bpt] for sa_bpt in sa_bpt_ordering])
+
+ def converted_bodyparts(self) -> list[str]:
+ """Returns: The project bodyparts included in this ordered"""
+ return [bpt for bpt in self.project_bodyparts if bpt in self.table]
+
+ def validate(self) -> None:
+ """
+ Raises:
+ ValueError: If the conversion table is misconfigured.
+ """
+ project_bpts = set(self.project_bodyparts)
+ sa_bpts = set(self.super_animal_bodyparts)
+
+ mapped_sa = set(self.table.values())
+ mapped_project = set(self.table.keys())
+
+ # check all mapped SuperAnimal bodyparts are in the config
+ if len(mapped_sa.difference(sa_bpts)) != 0:
+ extra_bodyparts = set(mapped_sa).difference(sa_bpts)
+ raise ValueError(
+ f"Some bodyparts in your mapping are not in the {self.super_animal} "
+ f"model: {extra_bodyparts}. Available bodyparts are {' '.join(sa_bpts)}"
+ )
+
+ # check all given bodyparts are in the project configuration
+ if len(mapped_project.difference(project_bpts)) != 0:
+ extra_bodyparts = mapped_project.difference(project_bpts)
+ raise ValueError(
+ "Some bodyparts in your mapping are not in your project configuration: "
+ f"{extra_bodyparts}. Defined bodyparts are {' '.join(project_bpts)}"
+ )
diff --git a/deeplabcut/core/crossvalutils.py b/deeplabcut/core/crossvalutils.py
new file mode 100644
index 0000000000..e95b2c7591
--- /dev/null
+++ b/deeplabcut/core/crossvalutils.py
@@ -0,0 +1,484 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+
+import os
+import pickle
+import shutil
+from collections import defaultdict
+from copy import deepcopy
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from scipy.spatial import cKDTree
+from sklearn.metrics.cluster import contingency_matrix
+from tqdm import tqdm
+
+from deeplabcut.core.inferenceutils import (
+ _parse_ground_truth_data,
+ Assembler,
+ evaluate_assembly,
+)
+from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions
+
+
+def _set_up_evaluation(data):
+ params = dict()
+ params["joint_names"] = data["metadata"]["all_joints_names"]
+ params["num_joints"] = len(params["joint_names"])
+ partaffinityfield_graph = data["metadata"]["PAFgraph"]
+ params["paf"] = np.arange(len(partaffinityfield_graph))
+ params["paf_graph"] = params["paf_links"] = [
+ partaffinityfield_graph[l] for l in params["paf"]
+ ]
+ params["bpts"] = params["ibpts"] = range(params["num_joints"])
+ params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
+ return params
+
+
+def _form_original_path(path):
+ root, filename = os.path.split(path)
+ base, ext = os.path.splitext(filename)
+ return os.path.join(root, filename.split("c")[0] + ext)
+
+
+def _unsorted_unique(array):
+ _, inds = np.unique(array, return_index=True)
+ return np.asarray(array)[np.sort(inds)]
+
+
+def find_closest_neighbors(
+ query: np.ndarray, ref: np.ndarray, k: int = 3
+) -> np.ndarray:
+ """Greedy matching of predicted keypoints to ground truth keypoints
+
+ Args:
+ query: the query keypoints
+ ref: the reference keypoints
+ k: The list of k-th nearest neighbors to return.
+
+ Returns:
+ an array of shape (len(query), ) containing the index of the closest
+ reference keypoint for each query keypoint
+ """
+ n_preds = ref.shape[0]
+ tree = cKDTree(ref)
+ dist, inds = tree.query(query, k=k)
+ idx = np.argsort(dist[:, 0])
+ neighbors = np.full(len(query), -1, dtype=int)
+ picked = {tree.n}
+ for i, ind in enumerate(inds[idx]):
+ for j in ind:
+ if j not in picked:
+ picked.add(j)
+ neighbors[idx[i]] = j
+ break
+ if len(picked) == (n_preds + 1):
+ break
+ return neighbors
+
+
+def _calc_separability(
+ vals_left, vals_right, n_bins=101, metric="jeffries", max_sensitivity=False
+):
+ if metric not in ("jeffries", "auc"):
+ raise ValueError("`metric` should be either 'jeffries' or 'auc'.")
+
+ bins = np.linspace(0, 1, n_bins)
+ hist_left = np.histogram(vals_left, bins=bins)[0]
+ hist_left = hist_left / hist_left.sum()
+ hist_right = np.histogram(vals_right, bins=bins)[0]
+ hist_right = hist_right / hist_right.sum()
+ tpr = np.cumsum(hist_right)
+ if metric == "jeffries":
+ sep = np.sqrt(
+ 2 * (1 - np.sum(np.sqrt(hist_left * hist_right)))
+ ) # Jeffries-Matusita distance
+ else:
+ sep = np.trapz(np.cumsum(hist_left), tpr)
+ if max_sensitivity:
+ threshold = bins[max(1, np.argmax(tpr > 0))]
+ else:
+ threshold = bins[np.argmin(1 - np.cumsum(hist_left) + tpr)]
+ return sep, threshold
+
+
+def _calc_within_between_pafs(
+ data,
+ metadata,
+ per_edge=True,
+ train_set_only=True,
+):
+ data = deepcopy(data)
+ train_inds = set(metadata["data"]["trainIndices"])
+ graph = data["metadata"]["PAFgraph"]
+ within_train = defaultdict(list)
+ within_test = defaultdict(list)
+ between_train = defaultdict(list)
+ between_test = defaultdict(list)
+ for i, (key, dict_) in enumerate(data.items()):
+ if key == "metadata":
+ continue
+
+ is_train = i in train_inds
+ if train_set_only and not is_train:
+ continue
+
+ df = dict_["groundtruth"][2]
+ try:
+ df.drop("single", level="individuals", inplace=True)
+ except KeyError:
+ pass
+ bpts = df.index.get_level_values("bodyparts").unique().to_list()
+ coords_gt = (
+ df.unstack(["individuals", "coords"])
+ .reindex(bpts, level="bodyparts")
+ .to_numpy()
+ .reshape((len(bpts), -1, 2))
+ )
+ if np.isnan(coords_gt).all():
+ continue
+
+ coords = dict_["prediction"]["coordinates"][0]
+ # Get animal IDs and corresponding indices in the arrays of detections
+ lookup = dict()
+ for i, (coord, coord_gt) in enumerate(zip(coords, coords_gt)):
+ inds = np.flatnonzero(np.all(~np.isnan(coord), axis=1))
+ inds_gt = np.flatnonzero(np.all(~np.isnan(coord_gt), axis=1))
+ if inds.size and inds_gt.size:
+ neighbors = find_closest_neighbors(coord_gt[inds_gt], coord[inds], k=3)
+ found = neighbors != -1
+ lookup[i] = dict(zip(inds_gt[found], inds[neighbors[found]]))
+
+ costs = dict_["prediction"]["costs"]
+ for k, v in costs.items():
+ paf = v["m1"]
+ mask_within = np.zeros(paf.shape, dtype=bool)
+ s, t = graph[k]
+ if s not in lookup or t not in lookup:
+ continue
+ lu_s = lookup[s]
+ lu_t = lookup[t]
+ common_id = set(lu_s).intersection(lu_t)
+ for id_ in common_id:
+ mask_within[lu_s[id_], lu_t[id_]] = True
+ within_vals = paf[mask_within]
+ between_vals = paf[~mask_within]
+ if is_train:
+ within_train[k].extend(within_vals)
+ between_train[k].extend(between_vals)
+ else:
+ within_test[k].extend(within_vals)
+ between_test[k].extend(between_vals)
+ if not per_edge:
+ within_train = np.concatenate([*within_train.values()])
+ within_test = np.concatenate([*within_test.values()])
+ between_train = np.concatenate([*between_train.values()])
+ between_test = np.concatenate([*between_test.values()])
+ return (within_train, within_test), (between_train, between_test)
+
+
+def _benchmark_paf_graphs(
+ config,
+ inference_cfg,
+ data,
+ paf_inds,
+ greedy=False,
+ add_discarded=True,
+ identity_only=False,
+ calibration_file="",
+ oks_sigma=0.1,
+ margin=0,
+ symmetric_kpts=None,
+ split_inds=None,
+):
+ metadata = data.pop("metadata")
+ multi_bpts_orig = auxfun_multianimal.extractindividualsandbodyparts(config)[2]
+ multi_bpts = [j for j in metadata["all_joints_names"] if j in multi_bpts_orig]
+ n_multi = len(multi_bpts)
+ data_ = {"metadata": metadata}
+ for k, v in data.items():
+ data_[k] = v["prediction"]
+ ass = Assembler(
+ data_,
+ max_n_individuals=inference_cfg["topktoretain"],
+ n_multibodyparts=n_multi,
+ greedy=greedy,
+ pcutoff=inference_cfg.get("pcutoff", 0.1),
+ min_affinity=inference_cfg.get("pafthreshold", 0.1),
+ add_discarded=add_discarded,
+ identity_only=identity_only,
+ )
+ if calibration_file:
+ ass.calibrate(calibration_file)
+
+ params = ass.metadata
+ image_paths = params["imnames"]
+ bodyparts = params["joint_names"]
+ idx = (
+ data[image_paths[0]]["groundtruth"][2]
+ .unstack("coords")
+ .reindex(bodyparts, level="bodyparts")
+ .index
+ )
+ mask_multi = idx.get_level_values("individuals") != "single"
+ if not mask_multi.all():
+ idx = idx.drop("single", level="individuals")
+ individuals = idx.get_level_values("individuals").unique()
+ n_individuals = len(individuals)
+ map_ = dict(zip(individuals, range(n_individuals)))
+
+ # Form ground truth beforehand
+ ground_truth = []
+ for i, imname in enumerate(image_paths):
+ temp = data[imname]["groundtruth"][2].reindex(multi_bpts, level="bodyparts")
+ ground_truth.append(temp.to_numpy().reshape((-1, 2)))
+ ground_truth = np.stack(ground_truth)
+ temp = np.ones((*ground_truth.shape[:2], 3))
+ temp[..., :2] = ground_truth
+ temp = temp.reshape((temp.shape[0], n_individuals, -1, 3))
+ ass_true_dict = _parse_ground_truth_data(temp)
+ ids = np.vectorize(map_.get)(idx.get_level_values("individuals").to_numpy())
+ ground_truth = np.insert(ground_truth, 2, ids, axis=2)
+
+ # Assemble animals on the full set of detections
+ paf_inds = sorted(paf_inds, key=len)
+ n_graphs = len(paf_inds)
+ all_scores = []
+ all_metrics = []
+ all_assemblies = []
+ for j, paf in enumerate(paf_inds, start=1):
+ print(f"Graph {j}|{n_graphs}")
+ ass.paf_inds = paf
+ ass.assemble()
+ all_assemblies.append((ass.assemblies, ass.unique, ass.metadata["imnames"]))
+ if split_inds is not None:
+ oks = []
+
+ # get the indices of the images in the training set
+ dataset_idx = [data[image_name]["index"] for image_name in image_paths]
+ for inds in split_inds:
+ ass_gt = {
+ k: v for k, v in ass_true_dict.items() if dataset_idx[k] in inds
+ }
+ ass_pred = {
+ k: v for k, v in ass.assemblies.items() if dataset_idx[k] in inds
+ }
+
+ oks.append(
+ evaluate_assembly(
+ ass_pred,
+ ass_gt,
+ oks_sigma,
+ margin=margin,
+ symmetric_kpts=symmetric_kpts,
+ greedy_matching=inference_cfg.get("greedy_oks", False),
+ )
+ )
+ else:
+ oks = evaluate_assembly(
+ ass.assemblies,
+ ass_true_dict,
+ oks_sigma,
+ margin=margin,
+ symmetric_kpts=symmetric_kpts,
+ greedy_matching=inference_cfg.get("greedy_oks", False),
+ )
+ all_metrics.append(oks)
+ scores = np.full((len(image_paths), 2), np.nan)
+ for i, imname in enumerate(tqdm(image_paths)):
+ gt = ground_truth[i]
+ gt = gt[~np.isnan(gt).any(axis=1)]
+ if len(np.unique(gt[:, 2])) < 2: # Only consider frames with 2+ animals
+ continue
+
+ # Count the number of unassembled bodyparts
+ n_dets = len(gt)
+ animals = ass.assemblies.get(i)
+ if animals is None:
+ if n_dets:
+ scores[i, 0] = 1
+ else:
+ animals = [
+ np.c_[animal.data, np.ones(animal.data.shape[0]) * n]
+ for n, animal in enumerate(animals)
+ ]
+ hyp = np.concatenate(animals)
+ hyp = hyp[~np.isnan(hyp).any(axis=1)]
+ scores[i, 0] = max(0, (n_dets - hyp.shape[0]) / n_dets)
+ neighbors = find_closest_neighbors(gt[:, :2], hyp[:, :2])
+ valid = neighbors != -1
+ id_gt = gt[valid, 2]
+ id_hyp = hyp[neighbors[valid], -1]
+ mat = contingency_matrix(id_gt, id_hyp)
+ purity = mat.max(axis=0).sum() / mat.sum()
+ scores[i, 1] = purity
+ all_scores.append((scores, paf))
+
+ dfs = []
+ for score, inds in all_scores:
+ df = pd.DataFrame(score, columns=["miss", "purity"])
+ df["ngraph"] = len(inds)
+ dfs.append(df)
+ big_df = pd.concat(dfs)
+ group = big_df.groupby("ngraph")
+ return (all_scores, group.agg(["mean", "std"]).T, all_metrics, all_assemblies)
+
+
+def _get_n_best_paf_graphs(
+ data,
+ metadata,
+ full_graph,
+ n_graphs=10,
+ root=None,
+ which="best",
+ ignore_inds=None,
+ metric="auc",
+):
+ if which not in ("best", "worst"):
+ raise ValueError('`which` must be either "best" or "worst"')
+
+ (within_train, _), (between_train, _) = _calc_within_between_pafs(
+ data,
+ metadata,
+ train_set_only=True,
+ )
+ # Handle unlabeled bodyparts...
+ existing_edges = set(k for k, v in within_train.items() if v)
+ if ignore_inds is not None:
+ existing_edges = existing_edges.difference(ignore_inds)
+ existing_edges = list(existing_edges)
+
+ if not any(between_train.values()):
+ # Only 1 animal, let us return the full graph indices only
+ return ([existing_edges], dict(zip(existing_edges, [0] * len(existing_edges))))
+
+ scores, _ = zip(
+ *[
+ _calc_separability(between_train[n], within_train[n], metric=metric)
+ for n in existing_edges
+ ]
+ )
+
+ # Find minimal skeleton
+ G = nx.Graph()
+ for edge, score in zip(existing_edges, scores):
+ if np.isfinite(score):
+ G.add_edge(*full_graph[edge], weight=score)
+ if which == "best":
+ order = np.asarray(existing_edges)[np.argsort(scores)[::-1]]
+ if root is None:
+ root = []
+ for edge in nx.maximum_spanning_edges(G, data=False):
+ root.append(full_graph.index(sorted(edge)))
+ else:
+ order = np.asarray(existing_edges)[np.argsort(scores)]
+ if root is None:
+ root = []
+ for edge in nx.minimum_spanning_edges(G, data=False):
+ root.append(full_graph.index(sorted(edge)))
+
+ n_edges = len(existing_edges) - len(root)
+ lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:]
+ order = order[np.isin(order, root, invert=True)]
+ paf_inds = [root]
+ for length in lengths:
+ paf_inds.append(root + list(order[:length]))
+ return paf_inds, dict(zip(existing_edges, scores))
+
+
+def cross_validate_paf_graphs(
+ config,
+ inference_config,
+ full_data_file,
+ metadata_file,
+ output_name="",
+ pcutoff=0.1,
+ oks_sigma=0.1,
+ margin=0,
+ greedy=False,
+ add_discarded=True,
+ calibrate=False,
+ overwrite_config=True,
+ n_graphs=10,
+ paf_inds=None,
+ symmetric_kpts=None,
+):
+ cfg = auxiliaryfunctions.read_config(config)
+ inf_cfg = auxiliaryfunctions.read_plainconfig(inference_config)
+ inf_cfg_temp = inf_cfg.copy()
+ inf_cfg_temp["pcutoff"] = pcutoff
+
+ with open(full_data_file, "rb") as file:
+ data = pickle.load(file)
+ with open(metadata_file, "rb") as file:
+ metadata = pickle.load(file)
+
+ params = _set_up_evaluation(data)
+ to_ignore = auxfun_multianimal.filter_unwanted_paf_connections(
+ cfg, params["paf_graph"]
+ )
+ best_graphs = _get_n_best_paf_graphs(
+ data,
+ metadata,
+ params["paf_graph"],
+ ignore_inds=to_ignore,
+ n_graphs=n_graphs,
+ )
+ paf_scores = best_graphs[1]
+ if paf_inds is None:
+ paf_inds = best_graphs[0]
+
+ if calibrate:
+ trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg)
+ calibration_file = os.path.join(
+ cfg["project_path"],
+ str(trainingsetfolder),
+ "CollectedData_" + cfg["scorer"] + ".h5",
+ )
+ else:
+ calibration_file = ""
+
+ results = _benchmark_paf_graphs(
+ cfg,
+ inf_cfg_temp,
+ data,
+ paf_inds,
+ greedy,
+ add_discarded,
+ oks_sigma=oks_sigma,
+ margin=margin,
+ symmetric_kpts=symmetric_kpts,
+ calibration_file=calibration_file,
+ split_inds=[
+ metadata["data"]["trainIndices"],
+ metadata["data"]["testIndices"],
+ ],
+ )
+ # Select optimal PAF graph
+ df = results[1]
+ size_opt = np.argmax((1 - df.loc["miss", "mean"]) * df.loc["purity", "mean"])
+ pose_config = inference_config.replace("inference_cfg", "pose_cfg")
+ if not overwrite_config:
+ shutil.copy(pose_config, pose_config.replace(".yaml", "_old.yaml"))
+ inds = list(paf_inds[size_opt])
+ auxiliaryfunctions.edit_config(
+ pose_config, {"paf_best": [int(ind) for ind in inds]}
+ )
+ if output_name:
+ with open(output_name, "wb") as file:
+ pickle.dump([results], file)
+ return results[:3], paf_scores, results[3][size_opt]
+
+
+# Backwards compatibility
+_find_closest_neighbors = find_closest_neighbors
diff --git a/deeplabcut/core/engine.py b/deeplabcut/core/engine.py
new file mode 100644
index 0000000000..c6f07ca69d
--- /dev/null
+++ b/deeplabcut/core/engine.py
@@ -0,0 +1,49 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Defines the deep learning frameworks available"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+
+
+@dataclass(frozen=True)
+class EngineDataMixin:
+ aliases: tuple[str]
+ model_folder_name: str
+ pose_cfg_name: str
+ results_folder_name: str
+
+
+class Engine(EngineDataMixin, Enum):
+ PYTORCH = (
+ ("pytorch", "torch"),
+ "dlc-models-pytorch",
+ "pytorch_config.yaml",
+ "evaluation-results-pytorch",
+ )
+ TF = (
+ ("tensorflow", "tf"),
+ "dlc-models",
+ "pose_cfg.yaml",
+ "evaluation-results",
+ )
+
+ @classmethod
+ def _missing_(cls, value):
+ if isinstance(value, str):
+ for member in cls:
+ if value.lower() in member.aliases:
+ return member
+ return None
+
+ def __repr__(self) -> str:
+ return f"Engine.{self.name}"
diff --git a/deeplabcut/core/inferenceutils.py b/deeplabcut/core/inferenceutils.py
new file mode 100644
index 0000000000..cbd21a877a
--- /dev/null
+++ b/deeplabcut/core/inferenceutils.py
@@ -0,0 +1,1314 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import heapq
+import itertools
+import multiprocessing
+import operator
+import pickle
+import warnings
+from collections import defaultdict
+from dataclasses import dataclass
+from math import erf, sqrt
+from typing import Any, Iterable, Tuple
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from scipy.optimize import linear_sum_assignment
+from scipy.spatial import cKDTree
+from scipy.spatial.distance import cdist, pdist
+from scipy.special import softmax
+from scipy.stats import chi2, gaussian_kde
+from tqdm import tqdm
+
+
+def _conv_square_to_condensed_indices(ind_row, ind_col, n):
+ if ind_row == ind_col:
+ raise ValueError("There are no diagonal elements in condensed matrices.")
+
+ if ind_row < ind_col:
+ ind_row, ind_col = ind_col, ind_row
+ return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col
+
+
+Position = Tuple[float, float]
+
+
+@dataclass(frozen=True)
+class Joint:
+ pos: Position
+ confidence: float = 1.0
+ label: int = None
+ idx: int = None
+ group: int = -1
+
+
+class Link:
+ def __init__(self, j1, j2, affinity=1):
+ self.j1 = j1
+ self.j2 = j2
+ self.affinity = affinity
+ self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2)
+
+ def __repr__(self):
+ return (
+ f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
+ )
+
+ @property
+ def confidence(self):
+ return self.j1.confidence * self.j2.confidence
+
+ @property
+ def idx(self):
+ return self.j1.idx, self.j2.idx
+
+ @property
+ def length(self):
+ return self._length
+
+ @length.setter
+ def length(self, length):
+ self._length = length
+
+ def to_vector(self):
+ return [*self.j1.pos, *self.j2.pos]
+
+
+class Assembly:
+ def __init__(self, size):
+ self.data = np.full((size, 4), np.nan)
+ self.confidence = 0 # 0 by default, overwritten otherwise with `add_joint`
+ self._affinity = 0
+ self._links = []
+ self._visible = set()
+ self._idx = set()
+ self._dict = dict()
+
+ def __len__(self):
+ return len(self._visible)
+
+ def __contains__(self, assembly):
+ return bool(self._visible.intersection(assembly._visible))
+
+ def __add__(self, other):
+ if other in self:
+ raise ValueError("Assemblies contain shared joints.")
+
+ assembly = Assembly(self.data.shape[0])
+ for link in self._links + other._links:
+ assembly.add_link(link)
+ return assembly
+
+ @classmethod
+ def from_array(cls, array):
+ n_bpts, n_cols = array.shape
+
+ # if a single coordinate is NaN for a bodypart, set all to NaN
+ array[np.isnan(array).any(axis=-1)] = np.nan
+
+ ass = cls(size=n_bpts)
+ ass.data[:, :n_cols] = array
+ visible = np.flatnonzero(~np.isnan(array).any(axis=1))
+ if n_cols < 3: # Only xy coordinates are being set
+ ass.data[visible, 2] = 1 # Set detection confidence to 1
+ ass._visible.update(visible)
+ return ass
+
+ @property
+ def xy(self):
+ return self.data[:, :2]
+
+ @property
+ def extent(self):
+ bbox = np.empty(4)
+ bbox[:2] = np.nanmin(self.xy, axis=0)
+ bbox[2:] = np.nanmax(self.xy, axis=0)
+ return bbox
+
+ @property
+ def area(self):
+ x1, y1, x2, y2 = self.extent
+ return (x2 - x1) * (y2 - y1)
+
+ @property
+ def confidence(self):
+ return np.nanmean(self.data[:, 2])
+
+ @confidence.setter
+ def confidence(self, confidence):
+ self.data[:, 2] = confidence
+
+ @property
+ def soft_identity(self):
+ data = self.data[~np.isnan(self.data).any(axis=1)]
+ unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True)
+ avg = np.bincount(idx, weights=data[:, 2]) / cnt
+ soft = softmax(avg)
+ return dict(zip(unq.astype(int), soft))
+
+ @property
+ def affinity(self):
+ n_links = self.n_links
+ if not n_links:
+ return 0
+ return self._affinity / n_links
+
+ @property
+ def n_links(self):
+ return len(self._links)
+
+ def intersection_with(self, other):
+ x11, y11, x21, y21 = self.extent
+ x12, y12, x22, y22 = other.extent
+ x1 = max(x11, x12)
+ y1 = max(y11, y12)
+ x2 = min(x21, x22)
+ y2 = min(y21, y22)
+ if x2 < x1 or y2 < y1:
+ return 0
+ ll = np.array([x1, y1])
+ ur = np.array([x2, y2])
+ xy1 = self.xy[~np.isnan(self.xy).any(axis=1)]
+ xy2 = other.xy[~np.isnan(other.xy).any(axis=1)]
+ in1 = np.all((xy1 >= ll) & (xy1 <= ur), axis=1).sum()
+ in2 = np.all((xy2 >= ll) & (xy2 <= ur), axis=1).sum()
+ return min(in1 / len(self), in2 / len(other))
+
+ def add_joint(self, joint):
+ if joint.label in self._visible or joint.label is None:
+ return False
+ self.data[joint.label] = *joint.pos, joint.confidence, joint.group
+ self._visible.add(joint.label)
+ self._idx.add(joint.idx)
+ return True
+
+ def remove_joint(self, joint):
+ if joint.label not in self._visible:
+ return False
+ self.data[joint.label] = np.nan
+ self._visible.remove(joint.label)
+ self._idx.remove(joint.idx)
+ return True
+
+ def add_link(self, link, store_dict=False):
+ if store_dict:
+ # Selective copy; deepcopy is >5x slower
+ self._dict = {
+ "data": self.data.copy(),
+ "_affinity": self._affinity,
+ "_links": self._links.copy(),
+ "_visible": self._visible.copy(),
+ "_idx": self._idx.copy(),
+ }
+ i1, i2 = link.idx
+ if i1 in self._idx and i2 in self._idx:
+ self._affinity += link.affinity
+ self._links.append(link)
+ return False
+ if link.j1.label in self._visible and link.j2.label in self._visible:
+ return False
+ self.add_joint(link.j1)
+ self.add_joint(link.j2)
+ self._affinity += link.affinity
+ self._links.append(link)
+ return True
+
+ def calc_pairwise_distances(self):
+ return pdist(self.xy, metric="sqeuclidean")
+
+
+class Assembler:
+ def __init__(
+ self,
+ data,
+ *,
+ max_n_individuals,
+ n_multibodyparts,
+ graph=None,
+ paf_inds=None,
+ greedy=False,
+ pcutoff=0.1,
+ min_affinity=0.05,
+ min_n_links=2,
+ max_overlap=0.8,
+ identity_only=False,
+ nan_policy="little",
+ force_fusion=False,
+ add_discarded=False,
+ window_size=0,
+ method="m1",
+ ):
+ self.data = data
+ self.metadata = self.parse_metadata(self.data)
+ self.max_n_individuals = max_n_individuals
+ self.n_multibodyparts = n_multibodyparts
+ self.n_uniquebodyparts = self.n_keypoints - n_multibodyparts
+ self.greedy = greedy
+ self.pcutoff = pcutoff
+ self.min_affinity = min_affinity
+ self.min_n_links = min_n_links
+ self.max_overlap = max_overlap
+ self._has_identity = "identity" in self[0]
+ if identity_only and not self._has_identity:
+ warnings.warn(
+ "The network was not trained with identity; setting `identity_only` to False."
+ )
+ self.identity_only = identity_only & self._has_identity
+ self.nan_policy = nan_policy
+ self.force_fusion = force_fusion
+ self.add_discarded = add_discarded
+ self.window_size = window_size
+ self.method = method
+ self.graph = graph or self.metadata["paf_graph"]
+ self.paf_inds = paf_inds or self.metadata["paf"]
+ self._gamma = 0.01
+ self._trees = dict()
+ self.safe_edge = False
+ self._kde = None
+ self.assemblies = dict()
+ self.unique = dict()
+
+ def __getitem__(self, item):
+ return self.data[self.metadata["imnames"][item]]
+
+ @classmethod
+ def empty(
+ cls,
+ max_n_individuals,
+ n_multibodyparts,
+ n_uniquebodyparts,
+ graph,
+ paf_inds,
+ greedy=False,
+ pcutoff=0.1,
+ min_affinity=0.05,
+ min_n_links=2,
+ max_overlap=0.8,
+ identity_only=False,
+ nan_policy="little",
+ force_fusion=False,
+ add_discarded=False,
+ window_size=0,
+ method="m1",
+ ):
+ # Dummy data
+ n_bodyparts = n_multibodyparts + n_uniquebodyparts
+ data = {
+ "metadata": {
+ "all_joints_names": ["" for _ in range(n_bodyparts)],
+ "PAFgraph": graph,
+ "PAFinds": paf_inds,
+ },
+ "0": {},
+ }
+ return cls(
+ data,
+ max_n_individuals=max_n_individuals,
+ n_multibodyparts=n_multibodyparts,
+ graph=graph,
+ paf_inds=paf_inds,
+ greedy=greedy,
+ pcutoff=pcutoff,
+ min_affinity=min_affinity,
+ min_n_links=min_n_links,
+ max_overlap=max_overlap,
+ identity_only=identity_only,
+ nan_policy=nan_policy,
+ force_fusion=force_fusion,
+ add_discarded=add_discarded,
+ window_size=window_size,
+ method=method,
+ )
+
+ @property
+ def n_keypoints(self):
+ return self.metadata["num_joints"]
+
+ def calibrate(self, train_data_file):
+ df = pd.read_hdf(train_data_file)
+ try:
+ df.drop("single", level="individuals", axis=1, inplace=True)
+ except KeyError:
+ pass
+ n_bpts = len(df.columns.get_level_values("bodyparts").unique())
+ if n_bpts == 1:
+ warnings.warn("There is only one keypoint; skipping calibration...")
+ return
+
+ xy = df.to_numpy().reshape((-1, n_bpts, 2))
+ frac_valid = np.mean(~np.isnan(xy), axis=(1, 2))
+ # Only keeps skeletons that are more than 90% complete
+ xy = xy[frac_valid >= 0.9]
+ if not xy.size:
+ warnings.warn("No complete poses were found. Skipping calibration...")
+ return
+
+ # TODO Normalize dists by longest length?
+ # TODO Smarter imputation technique (Bayesian? Grassmann averages?)
+ dists = np.vstack([pdist(data, "sqeuclidean") for data in xy])
+ mu = np.nanmean(dists, axis=0)
+ missing = np.isnan(dists)
+ dists = np.where(missing, mu, dists)
+ try:
+ kde = gaussian_kde(dists.T)
+ kde.mean = mu
+ self._kde = kde
+ self.safe_edge = True
+ except np.linalg.LinAlgError:
+ # Covariance matrix estimation fails due to numerical singularities
+ warnings.warn(
+ "The assembler could not be robustly calibrated. Continuing without it..."
+ )
+
+ def calc_assembly_mahalanobis_dist(
+ self, assembly, return_proba=False, nan_policy="little"
+ ):
+ if self._kde is None:
+ raise ValueError("Assembler should be calibrated first with training data.")
+
+ dists = assembly.calc_pairwise_distances() - self._kde.mean
+ mask = np.isnan(dists)
+ # Distance is undefined if the assembly is empty
+ if not len(assembly) or mask.all():
+ if return_proba:
+ return np.inf, 0
+ return np.inf
+
+ if nan_policy == "little":
+ inds = np.flatnonzero(~mask)
+ dists = dists[inds]
+ inv_cov = self._kde.inv_cov[np.ix_(inds, inds)]
+ # Correct distance to account for missing observations
+ factor = self._kde.d / len(inds)
+ else:
+ # Alternatively, reduce contribution of missing values to the Mahalanobis
+ # distance to zero by substituting the corresponding means.
+ dists[mask] = 0
+ mask.fill(False)
+ inv_cov = self._kde.inv_cov
+ factor = 1
+ dot = dists @ inv_cov
+ mahal = factor * sqrt(np.sum((dot * dists), axis=-1))
+ if return_proba:
+ proba = 1 - chi2.cdf(mahal, np.sum(~mask))
+ return mahal, proba
+ return mahal
+
+ def calc_link_probability(self, link):
+ if self._kde is None:
+ raise ValueError("Assembler should be calibrated first with training data.")
+
+ i = link.j1.label
+ j = link.j2.label
+ ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts)
+ mu = self._kde.mean[ind]
+ sigma = self._kde.covariance[ind, ind]
+ z = (link.length**2 - mu) / sigma
+ return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))
+
+ @staticmethod
+ def _flatten_detections(data_dict):
+ ind = 0
+ coordinates = data_dict["coordinates"][0]
+ confidence = data_dict["confidence"]
+ ids = data_dict.get("identity", None)
+ if ids is None:
+ ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence]
+ else:
+ ids = [arr.argmax(axis=1) for arr in ids]
+ for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)):
+ if not np.any(coords):
+ continue
+ for xy, p, g in zip(coords, conf, id_):
+ joint = Joint(tuple(xy), p.item(), i, ind, g)
+ ind += 1
+ yield joint
+
+ def extract_best_links(self, joints_dict, costs, trees=None):
+ links = []
+ for ind in self.paf_inds:
+ s, t = self.graph[ind]
+ dets_s = joints_dict.get(s, None)
+ dets_t = joints_dict.get(t, None)
+ if dets_s is None or dets_t is None:
+ continue
+ if ind not in costs:
+ continue
+ lengths = costs[ind]["distance"]
+ if np.isinf(lengths).all():
+ continue
+ aff = costs[ind][self.method].copy()
+ aff[np.isnan(aff)] = 0
+
+ if trees:
+ vecs = np.vstack(
+ [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]
+ )
+ dists = []
+ for n, tree in enumerate(trees, start=1):
+ d, _ = tree.query(vecs)
+ dists.append(np.exp(-self._gamma * n * d))
+ w = np.mean(dists, axis=0)
+ aff *= w.reshape(aff.shape)
+
+ if self.greedy:
+ conf = np.asarray(
+ [
+ [det_s.confidence * det_t.confidence for det_t in dets_t]
+ for det_s in dets_s
+ ]
+ )
+ rows, cols = np.where(
+ (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)
+ )
+ candidates = sorted(
+ zip(rows, cols, aff[rows, cols], lengths[rows, cols]),
+ key=lambda x: x[2],
+ reverse=True,
+ )
+ i_seen = set()
+ j_seen = set()
+ for i, j, w, l in candidates:
+ if i not in i_seen and j not in j_seen:
+ i_seen.add(i)
+ j_seen.add(j)
+ links.append(Link(dets_s[i], dets_t[j], w))
+ if len(i_seen) == self.max_n_individuals:
+ break
+ else: # Optimal keypoint pairing
+ inds_s = sorted(
+ range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True
+ )[: self.max_n_individuals]
+ inds_t = sorted(
+ range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True
+ )[: self.max_n_individuals]
+ keep_s = [
+ ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff
+ ]
+ keep_t = [
+ ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff
+ ]
+ aff = aff[np.ix_(keep_s, keep_t)]
+ rows, cols = linear_sum_assignment(aff, maximize=True)
+ for row, col in zip(rows, cols):
+ w = aff[row, col]
+ if w >= self.min_affinity:
+ links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w))
+ return links
+
+ def _fill_assembly(self, assembly, lookup, assembled, safe_edge, nan_policy):
+ stack = []
+ visited = set()
+ tabu = []
+ counter = itertools.count()
+
+ def push_to_stack(i):
+ for j, link in lookup[i].items():
+ if j in assembly._idx:
+ continue
+ if link.idx in visited:
+ continue
+ heapq.heappush(stack, (-link.affinity, next(counter), link))
+ visited.add(link.idx)
+
+ for idx in assembly._idx:
+ push_to_stack(idx)
+
+ while stack and len(assembly) < self.n_multibodyparts:
+ _, _, best = heapq.heappop(stack)
+ i, j = best.idx
+ if i in assembly._idx:
+ new_ind = j
+ elif j in assembly._idx:
+ new_ind = i
+ else:
+ continue
+ if new_ind in assembled:
+ continue
+ if safe_edge:
+ d_old = self.calc_assembly_mahalanobis_dist(
+ assembly, nan_policy=nan_policy
+ )
+ success = assembly.add_link(best, store_dict=True)
+ if not success:
+ assembly._dict = dict()
+ continue
+ d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
+ if d < d_old:
+ push_to_stack(new_ind)
+ try:
+ _, _, link = heapq.heappop(tabu)
+ heapq.heappush(stack, (-link.affinity, next(counter), link))
+ except IndexError:
+ pass
+ else:
+ heapq.heappush(tabu, (d - d_old, next(counter), best))
+ assembly.__dict__.update(assembly._dict)
+ assembly._dict = dict()
+ else:
+ assembly.add_link(best)
+ push_to_stack(new_ind)
+
+ def build_assemblies(self, links):
+ lookup = defaultdict(dict)
+ for link in links:
+ i, j = link.idx
+ lookup[i][j] = link
+ lookup[j][i] = link
+
+ assemblies = []
+ assembled = set()
+
+ # Fill the subsets with unambiguous, complete individuals
+ G = nx.Graph([link.idx for link in links])
+ for chain in nx.connected_components(G):
+ if len(chain) == self.n_multibodyparts:
+ edges = [tuple(sorted(edge)) for edge in G.edges(chain)]
+ assembly = Assembly(self.n_multibodyparts)
+ for link in links:
+ i, j = link.idx
+ if (i, j) in edges:
+ success = assembly.add_link(link)
+ if success:
+ lookup[i].pop(j)
+ lookup[j].pop(i)
+ assembled.update(assembly._idx)
+ assemblies.append(assembly)
+
+ if len(assemblies) == self.max_n_individuals:
+ return assemblies, assembled
+
+ for link in sorted(links, key=lambda x: x.affinity, reverse=True):
+ if any(i in assembled for i in link.idx):
+ continue
+ assembly = Assembly(self.n_multibodyparts)
+ assembly.add_link(link)
+ self._fill_assembly(
+ assembly, lookup, assembled, self.safe_edge, self.nan_policy
+ )
+ for link in assembly._links:
+ i, j = link.idx
+ lookup[i].pop(j)
+ lookup[j].pop(i)
+ assembled.update(assembly._idx)
+ assemblies.append(assembly)
+
+ # Fuse superfluous assemblies
+ n_extra = len(assemblies) - self.max_n_individuals
+ if n_extra > 0:
+ if self.safe_edge:
+ ds_old = [
+ self.calc_assembly_mahalanobis_dist(assembly)
+ for assembly in assemblies
+ ]
+ while len(assemblies) > self.max_n_individuals:
+ ds = []
+ for i, j in itertools.combinations(range(len(assemblies)), 2):
+ if assemblies[j] not in assemblies[i]:
+ temp = assemblies[i] + assemblies[j]
+ d = self.calc_assembly_mahalanobis_dist(temp)
+ delta = d - max(ds_old[i], ds_old[j])
+ ds.append((i, j, delta, d, temp))
+ if not ds:
+ break
+ min_ = sorted(ds, key=lambda x: x[2])
+ i, j, delta, d, new = min_[0]
+ if delta < 0 or len(min_) == 1:
+ assemblies[i] = new
+ assemblies.pop(j)
+ ds_old[i] = d
+ ds_old.pop(j)
+ else:
+ break
+ elif self.force_fusion:
+ assemblies = sorted(assemblies, key=len)
+ for nrow in range(n_extra):
+ assembly = assemblies[nrow]
+ candidates = [a for a in assemblies[nrow:] if assembly not in a]
+ if not candidates:
+ continue
+ if len(candidates) == 1:
+ candidate = candidates[0]
+ else:
+ dists = []
+ for cand in candidates:
+ d = cdist(assembly.xy, cand.xy)
+ dists.append(np.nanmin(d))
+ candidate = candidates[np.argmin(dists)]
+ ind = assemblies.index(candidate)
+ assemblies[ind] += assembly
+ else:
+ store = dict()
+ for assembly in assemblies:
+ if len(assembly) != self.n_multibodyparts:
+ for i in assembly._idx:
+ store[i] = assembly
+ used = [link for assembly in assemblies for link in assembly._links]
+ unconnected = [link for link in links if link not in used]
+ for link in unconnected:
+ i, j = link.idx
+ try:
+ if store[j] not in store[i]:
+ temp = store[i] + store[j]
+ store[i].__dict__.update(temp.__dict__)
+ assemblies.remove(store[j])
+ for idx in store[j]._idx:
+ store[idx] = store[i]
+ except KeyError:
+ pass
+
+ # Second pass without edge safety
+ for assembly in assemblies:
+ if len(assembly) != self.n_multibodyparts:
+ self._fill_assembly(assembly, lookup, assembled, False, "")
+ assembled.update(assembly._idx)
+
+ return assemblies, assembled
+
+ def _assemble(self, data_dict, ind_frame):
+ joints = list(self._flatten_detections(data_dict))
+ if not joints:
+ return None, None
+
+ bag = defaultdict(list)
+ for joint in joints:
+ bag[joint.label].append(joint)
+
+ assembled = set()
+
+ if self.n_uniquebodyparts:
+ unique = np.full((self.n_uniquebodyparts, 3), np.nan)
+ for n, ind in enumerate(range(self.n_multibodyparts, self.n_keypoints)):
+ dets = bag[ind]
+ if not dets:
+ continue
+ if len(dets) > 1:
+ det = max(dets, key=lambda x: x.confidence)
+ else:
+ det = dets[0]
+ # Mark the unique body parts as assembled anyway so
+ # they are not used later on to fill assemblies.
+ assembled.update(d.idx for d in dets)
+ if det.confidence <= self.pcutoff and not self.add_discarded:
+ continue
+ unique[n] = *det.pos, det.confidence
+ if np.isnan(unique).all():
+ unique = None
+ else:
+ unique = None
+
+ if not any(i in bag for i in range(self.n_multibodyparts)):
+ return None, unique
+
+ if self.n_multibodyparts == 1:
+ assemblies = []
+ for joint in bag[0]:
+ if joint.confidence >= self.pcutoff:
+ ass = Assembly(self.n_multibodyparts)
+ ass.add_joint(joint)
+ assemblies.append(ass)
+ return assemblies, unique
+
+ if self.max_n_individuals == 1:
+ get_attr = operator.attrgetter("confidence")
+ ass = Assembly(self.n_multibodyparts)
+ for ind in range(self.n_multibodyparts):
+ joints = bag[ind]
+ if not joints:
+ continue
+ ass.add_joint(max(joints, key=get_attr))
+ return [ass], unique
+
+ if self.identity_only:
+ assemblies = []
+ get_attr = operator.attrgetter("group")
+ temp = sorted(
+ (joint for joint in joints if np.isfinite(joint.confidence)),
+ key=get_attr,
+ )
+ groups = itertools.groupby(temp, get_attr)
+ for _, group in groups:
+ ass = Assembly(self.n_multibodyparts)
+ for joint in sorted(group, key=lambda x: x.confidence, reverse=True):
+ if (
+ joint.confidence >= self.pcutoff
+ and joint.label < self.n_multibodyparts
+ ):
+ ass.add_joint(joint)
+ if len(ass):
+ assemblies.append(ass)
+ assembled.update(ass._idx)
+ else:
+ trees = []
+ for j in range(1, self.window_size + 1):
+ tree = self._trees.get(ind_frame - j, None)
+ if tree is not None:
+ trees.append(tree)
+
+ links = self.extract_best_links(bag, data_dict["costs"], trees)
+ if self._kde:
+ for link in links[::-1]:
+ p = max(self.calc_link_probability(link), 0.001)
+ link.affinity *= p
+ if link.affinity < self.min_affinity:
+ links.remove(link)
+
+ if self.window_size >= 1 and links:
+ # Store selected edges for subsequent frames
+ vecs = np.vstack([link.to_vector() for link in links])
+ self._trees[ind_frame] = cKDTree(vecs)
+
+ assemblies, assembled_ = self.build_assemblies(links)
+ assembled.update(assembled_)
+
+ # Remove invalid assemblies
+ discarded = set(
+ joint
+ for joint in joints
+ if joint.idx not in assembled and np.isfinite(joint.confidence)
+ )
+ for assembly in assemblies[::-1]:
+ if 0 < assembly.n_links < self.min_n_links or not len(assembly):
+ for link in assembly._links:
+ discarded.update((link.j1, link.j2))
+ assemblies.remove(assembly)
+ if 0 < self.max_overlap < 1: # Non-maximum pose suppression
+ if self._kde is not None:
+ scores = [
+ -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies
+ ]
+ else:
+ scores = [ass._affinity for ass in assemblies]
+ lst = list(zip(scores, assemblies))
+ assemblies = []
+ while lst:
+ temp = max(lst, key=lambda x: x[0])
+ lst.remove(temp)
+ assemblies.append(temp[1])
+ for pair in lst[::-1]:
+ if temp[1].intersection_with(pair[1]) >= self.max_overlap:
+ lst.remove(pair)
+ if len(assemblies) > self.max_n_individuals:
+ assemblies = sorted(assemblies, key=len, reverse=True)
+ for assembly in assemblies[self.max_n_individuals :]:
+ for link in assembly._links:
+ discarded.update((link.j1, link.j2))
+ assemblies = assemblies[: self.max_n_individuals]
+
+ if self.add_discarded and discarded:
+ # Fill assemblies with unconnected body parts
+ for joint in sorted(discarded, key=lambda x: x.confidence, reverse=True):
+ if self.safe_edge:
+ for assembly in assemblies:
+ if joint.label in assembly._visible:
+ continue
+ d_old = self.calc_assembly_mahalanobis_dist(assembly)
+ assembly.add_joint(joint)
+ d = self.calc_assembly_mahalanobis_dist(assembly)
+ if d < d_old:
+ break
+ assembly.remove_joint(joint)
+ else:
+ dists = []
+ for i, assembly in enumerate(assemblies):
+ if joint.label in assembly._visible:
+ continue
+ d = cdist(assembly.xy, np.atleast_2d(joint.pos))
+ dists.append((i, np.nanmin(d)))
+ if not dists:
+ continue
+ min_ = sorted(dists, key=lambda x: x[1])
+ ind, _ = min_[0]
+ assemblies[ind].add_joint(joint)
+
+ return assemblies, unique
+
+ def assemble(self, chunk_size=1, n_processes=None):
+ self.assemblies = dict()
+ self.unique = dict()
+ # Spawning (rather than forking) multiple processes does not
+ # work nicely with the GUI or interactive sessions.
+ # In that case, we fall back to the serial assembly.
+ if chunk_size == 0 or multiprocessing.get_start_method() == "spawn":
+
+ for i, data_dict in enumerate(tqdm(self)):
+ assemblies, unique = self._assemble(data_dict, i)
+ if assemblies:
+ self.assemblies[i] = assemblies
+ if unique is not None:
+ self.unique[i] = unique
+ else:
+ global wrapped # Hack to make the function pickable
+
+ def wrapped(i):
+ return i, self._assemble(self[i], i)
+
+ n_frames = len(self.metadata["imnames"])
+ with multiprocessing.Pool(n_processes) as p:
+ with tqdm(total=n_frames) as pbar:
+ for i, (assemblies, unique) in p.imap_unordered(
+ wrapped, range(n_frames), chunksize=chunk_size
+ ):
+ if assemblies:
+ self.assemblies[i] = assemblies
+ if unique is not None:
+ self.unique[i] = unique
+ pbar.update()
+
+ def from_pickle(self, pickle_path):
+ with open(pickle_path, "rb") as file:
+ data = pickle.load(file)
+ self.unique = data.pop("single", {})
+ self.assemblies = data
+
+ @staticmethod
+ def parse_metadata(data):
+ params = dict()
+ params["joint_names"] = data["metadata"]["all_joints_names"]
+ params["num_joints"] = len(params["joint_names"])
+ params["paf_graph"] = data["metadata"]["PAFgraph"]
+ params["paf"] = data["metadata"].get(
+ "PAFinds", np.arange(len(params["joint_names"]))
+ )
+ params["bpts"] = params["ibpts"] = range(params["num_joints"])
+ params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
+ return params
+
+ def to_h5(self, output_name):
+ data = np.full(
+ (
+ len(self.metadata["imnames"]),
+ self.max_n_individuals,
+ self.n_multibodyparts,
+ 4,
+ ),
+ fill_value=np.nan,
+ )
+ for ind, assemblies in self.assemblies.items():
+ for n, assembly in enumerate(assemblies):
+ data[ind, n] = assembly.data
+ index = pd.MultiIndex.from_product(
+ [
+ ["scorer"],
+ map(str, range(self.max_n_individuals)),
+ map(str, range(self.n_multibodyparts)),
+ ["x", "y", "likelihood"],
+ ],
+ names=["scorer", "individuals", "bodyparts", "coords"],
+ )
+ temp = data[..., :3].reshape((data.shape[0], -1))
+ df = pd.DataFrame(temp, columns=index)
+ df.to_hdf(output_name, key="ass")
+
+ def to_pickle(self, output_name):
+ data = dict()
+ for ind, assemblies in self.assemblies.items():
+ data[ind] = [ass.data for ass in assemblies]
+ if self.unique:
+ data["single"] = self.unique
+ with open(output_name, "wb") as file:
+ pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)
+
+
+@dataclass
+class MatchedPrediction:
+ """A match between a prediction and a ground truth assembly
+
+ The ground truth assembly should be None f the prediction was not matched to any GT,
+ and the OKS should be 0.
+
+ Attributes:
+ prediction: A prediction made by a pose model.
+ score: The confidence score for the prediction.
+ ground_truth: If None, then this prediction is not matched to any ground truth
+ (this can happen when there are more predicted individuals than GT).
+ Otherwise, the ground truth assembly to which this prediction is matched.
+ oks: The OKS score between the prediction and the ground truth pose.
+ """
+
+ prediction: Assembly
+ score: float
+ ground_truth: Assembly | None
+ oks: float
+
+
+def calc_object_keypoint_similarity(
+ xy_pred,
+ xy_true,
+ sigma,
+ margin=0,
+ symmetric_kpts=None,
+):
+ visible_gt = ~np.isnan(xy_true).all(axis=1)
+ if visible_gt.sum() < 2: # At least 2 points needed to calculate scale
+ return np.nan
+
+ true = xy_true[visible_gt]
+ scale_squared = np.prod(np.ptp(true, axis=0) + np.spacing(1) + margin * 2)
+ if np.isclose(scale_squared, 0):
+ return np.nan
+
+ k_squared = (2 * sigma) ** 2
+ denom = 2 * scale_squared * k_squared
+ if symmetric_kpts is None:
+ pred = xy_pred[visible_gt]
+ pred[np.isnan(pred)] = np.inf
+ dist_squared = np.sum((pred - true) ** 2, axis=1)
+ oks = np.exp(-dist_squared / denom)
+ return np.mean(oks)
+ else:
+ oks = []
+ xy_preds = [xy_pred]
+ combos = (
+ pair
+ for l in range(len(symmetric_kpts))
+ for pair in itertools.combinations(symmetric_kpts, l + 1)
+ )
+ for pairs in combos:
+ # Swap corresponding keypoints
+ tmp = xy_pred.copy()
+ for pair in pairs:
+ tmp[pair, :] = tmp[pair[::-1], :]
+ xy_preds.append(tmp)
+ for xy_pred in xy_preds:
+ pred = xy_pred[visible_gt]
+ pred[np.isnan(pred)] = np.inf
+ dist_squared = np.sum((pred - true) ** 2, axis=1)
+ oks.append(np.mean(np.exp(-dist_squared / denom)))
+ return max(oks)
+
+
+def match_assemblies(
+ predictions: list[Assembly],
+ ground_truth: list[Assembly],
+ sigma: float,
+ margin: int = 0,
+ symmetric_kpts: list[tuple[int, int]] | None = None,
+ greedy_matching: bool = False,
+ greedy_oks_threshold: float = 0.0,
+) -> tuple[int, list[MatchedPrediction]]:
+ """Matches assemblies to ground truth predictions
+
+ Returns:
+ int: the total number of valid ground truth assemblies
+ list[MatchedPrediction]: a list containing all valid predictions, potentially
+ matched to ground truth assemblies.
+ """
+ # Only consider assemblies of at least two keypoints
+ predictions = [a for a in predictions if len(a) > 1]
+ ground_truth = [a for a in ground_truth if len(a) > 1]
+ num_ground_truth = len(ground_truth)
+
+ # Sort predictions by score
+ inds_pred = np.argsort(
+ [ins.affinity if ins.n_links else ins.confidence for ins in predictions]
+ )[::-1]
+ predictions = np.asarray(predictions)[inds_pred]
+
+ # indices of unmatched ground truth assemblies
+ matched = [
+ MatchedPrediction(
+ prediction=p,
+ score=(p.affinity if p.n_links else p.confidence),
+ ground_truth=None,
+ oks=0.0,
+ )
+ for p in predictions
+ ]
+
+ # Greedy assembly matching like in pycocotools
+ if greedy_matching:
+ matched_gt_indices = set()
+ for idx, pred in enumerate(predictions):
+ oks = [
+ calc_object_keypoint_similarity(
+ pred.xy,
+ gt.xy,
+ sigma,
+ margin,
+ symmetric_kpts,
+ )
+ for gt in ground_truth
+ ]
+ if np.all(np.isnan(oks)):
+ continue
+
+ ind_best = np.nanargmax(oks)
+
+ # if this gt already matched, and not a crowd, continue
+ if ind_best in matched_gt_indices:
+ continue
+
+ # Only match the pred to the GT if the OKS value is above a given threshold
+ if oks[ind_best] < greedy_oks_threshold:
+ continue
+
+ matched_gt_indices.add(ind_best)
+ matched[idx].ground_truth = ground_truth[ind_best]
+ matched[idx].oks = oks[ind_best]
+
+ # Global rather than greedy assembly matching
+ else:
+ inds_true = list(range(len(ground_truth)))
+ mat = np.zeros((len(predictions), len(ground_truth)))
+ for i, a_pred in enumerate(predictions):
+ for j, a_true in enumerate(ground_truth):
+ oks = calc_object_keypoint_similarity(
+ a_pred.xy,
+ a_true.xy,
+ sigma,
+ margin,
+ symmetric_kpts,
+ )
+ if ~np.isnan(oks):
+ mat[i, j] = oks
+ rows, cols = linear_sum_assignment(mat, maximize=True)
+ for row, col in zip(rows, cols):
+ matched[row].ground_truth = ground_truth[col]
+ matched[row].oks = mat[row, col]
+ _ = inds_true.remove(col)
+
+ return num_ground_truth, matched
+
+
+def parse_ground_truth_data_file(h5_file):
+ df = pd.read_hdf(h5_file)
+ try:
+ df.drop("single", axis=1, level="individuals", inplace=True)
+ except KeyError:
+ pass
+ # Cast columns of dtype 'object' to float to avoid TypeError
+ # further down in _parse_ground_truth_data.
+ cols = df.select_dtypes(include="object").columns
+ if cols.to_list():
+ df[cols] = df[cols].astype("float")
+ n_individuals = len(df.columns.get_level_values("individuals").unique())
+ n_bodyparts = len(df.columns.get_level_values("bodyparts").unique())
+ data = df.to_numpy().reshape((df.shape[0], n_individuals, n_bodyparts, -1))
+ return _parse_ground_truth_data(data)
+
+
+def _parse_ground_truth_data(data):
+ gt = dict()
+ for i, arr in enumerate(data):
+ temp = []
+ for row in arr:
+ if np.isnan(row[:, :2]).all():
+ continue
+ ass = Assembly.from_array(row)
+ temp.append(ass)
+ if not temp:
+ continue
+ gt[i] = temp
+ return gt
+
+
+def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
+ if not hasattr(Assembly, criterion):
+ raise ValueError(f"Invalid criterion {criterion}.")
+
+ if len(qs) != 2:
+ raise ValueError(
+ "Two percentiles (for lower and upper bounds) should be given."
+ )
+
+ tuples = []
+ for frame_ind, assemblies in dict_of_assemblies.items():
+ for assembly in assemblies:
+ tuples.append((frame_ind, getattr(assembly, criterion)))
+ frame_inds, vals = zip(*tuples)
+ vals = np.asarray(vals)
+ lo, up = np.percentile(vals, qs, interpolation="nearest")
+ inds = np.flatnonzero((vals < lo) | (vals > up)).tolist()
+ return list(set(frame_inds[i] for i in inds))
+
+
+def _compute_precision_and_recall(
+ num_gt_assemblies: int,
+ oks_values: np.ndarray,
+ oks_threshold: float,
+ recall_thresholds: np.ndarray,
+) -> tuple[np.ndarray, np.ndarray]:
+ """Computes the precision and recall scores at a given OKS threshold
+
+ Args:
+ num_gt_assemblies: the number of ground truth assemblies (used to compute false
+ negatives + true positives).
+ oks_values: the OKS value to the matched GT assembly for each prediction
+ oks_threshold: the OKS threshold at which recall and precision are being
+ computed
+ recall_thresholds: the recall thresholds to use to compute scores
+
+ Returns:
+ The precision and recall arrays at each recall threshold
+ """
+ tp = np.cumsum(oks_values >= oks_threshold)
+ fp = np.cumsum(oks_values < oks_threshold)
+ rc = tp / num_gt_assemblies
+ pr = tp / (fp + tp + np.spacing(1))
+ recall = rc[-1]
+
+ # Guarantee precision decreases monotonically, see
+ # https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173
+ for i in range(len(pr) - 1, 0, -1):
+ if pr[i] > pr[i - 1]:
+ pr[i - 1] = pr[i]
+
+ inds_rc = np.searchsorted(rc, recall_thresholds, side="left")
+ precision = np.zeros(inds_rc.shape)
+ valid = inds_rc < len(pr)
+ precision[valid] = pr[inds_rc[valid]]
+ return precision, recall
+
+
+def evaluate_assembly_greedy(
+ assemblies_gt: dict[Any, list[Assembly]],
+ assemblies_pred: dict[Any, list[Assembly]],
+ oks_sigma: float,
+ oks_thresholds: Iterable[float],
+ margin: int | float = 0,
+ symmetric_kpts: list[tuple[int, int]] | None = None,
+) -> dict:
+ """Runs greedy mAP evaluation, as done by pycocotools
+
+ Args:
+ assemblies_gt: A dictionary mapping image ID (e.g. filepath) to ground truth
+ assemblies. Should contain all the same keys as ``assemblies_pred``.
+ assemblies_pred: A dictionary mapping image ID (e.g. filepath) to predicted
+ assemblies. Should contain all the same keys as ``assemblies_gt``.
+ oks_sigma: The sigma to use to compute OKS values for keypoints .
+ oks_thresholds: The OKS thresholds at which to compute precision & recall.
+ margin: The margin to use to compute bounding boxes from keypoints.
+ symmetric_kpts: The symmetric keypoints in the dataset.
+ """
+ recall_thresholds = np.linspace( # np.linspace(0, 1, 101)
+ start=0.0, stop=1.00, num=int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True
+ )
+ precisions = []
+ recalls = []
+ for oks_t in oks_thresholds:
+ all_matched = []
+ total_gt_assemblies = 0
+ for ind, gt_assembly in assemblies_gt.items():
+ pred_assemblies = assemblies_pred.get(ind, [])
+ num_gt_assemblies, matched = match_assemblies(
+ pred_assemblies,
+ gt_assembly,
+ oks_sigma,
+ margin,
+ symmetric_kpts,
+ greedy_matching=True,
+ greedy_oks_threshold=oks_t,
+ )
+ all_matched.extend(matched)
+ total_gt_assemblies += num_gt_assemblies
+
+ if len(all_matched) == 0:
+ precisions.append(0.0)
+ recalls.append(0.0)
+ continue
+
+ # Global sort of assemblies (across all images) by score
+ scores = np.asarray([-m.score for m in all_matched])
+ sorted_pred_indices = np.argsort(scores, kind="mergesort")
+ oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices]
+
+ # Compute prediction and recall
+ p, r = _compute_precision_and_recall(
+ total_gt_assemblies, oks, oks_t, recall_thresholds
+ )
+ precisions.append(p)
+ recalls.append(r)
+
+ precisions = np.asarray(precisions)
+ recalls = np.asarray(recalls)
+ return {
+ "precisions": precisions,
+ "recalls": recalls,
+ "mAP": precisions.mean(),
+ "mAR": recalls.mean(),
+ }
+
+
+def evaluate_assembly(
+ ass_pred_dict,
+ ass_true_dict,
+ oks_sigma=0.072,
+ oks_thresholds=np.linspace(0.5, 0.95, 10),
+ margin=0,
+ symmetric_kpts=None,
+ greedy_matching=False,
+ with_tqdm: bool = True,
+):
+ if greedy_matching:
+ return evaluate_assembly_greedy(
+ ass_true_dict,
+ ass_pred_dict,
+ oks_sigma=oks_sigma,
+ oks_thresholds=oks_thresholds,
+ margin=margin,
+ symmetric_kpts=symmetric_kpts,
+ )
+
+ # sigma is taken as the median of all COCO keypoint standard deviations
+ all_matched = []
+ total_gt_assemblies = 0
+
+ gt_assemblies = ass_true_dict.items()
+ if with_tqdm:
+ gt_assemblies = tqdm(gt_assemblies)
+
+ for ind, gt_assembly in gt_assemblies:
+ pred_assemblies = ass_pred_dict.get(ind, [])
+ num_gt, matched = match_assemblies(
+ pred_assemblies,
+ gt_assembly,
+ oks_sigma,
+ margin,
+ symmetric_kpts,
+ greedy_matching,
+ )
+ all_matched.extend(matched)
+ total_gt_assemblies += num_gt
+
+ if not all_matched:
+ return {
+ "precisions": np.array([]),
+ "recalls": np.array([]),
+ "mAP": 0.0,
+ "mAR": 0.0,
+ }
+
+ conf_pred = np.asarray([match.score for match in all_matched])
+ idx = np.argsort(-conf_pred, kind="mergesort")
+ # Sort matching score (OKS) in descending order of assembly affinity
+ oks = np.asarray([match.oks for match in all_matched])[idx]
+ recall_thresholds = np.linspace(0, 1, 101)
+ precisions = []
+ recalls = []
+ for t in oks_thresholds:
+ p, r = _compute_precision_and_recall(
+ total_gt_assemblies, oks, t, recall_thresholds
+ )
+ precisions.append(p)
+ recalls.append(r)
+
+ precisions = np.asarray(precisions)
+ recalls = np.asarray(recalls)
+ return {
+ "precisions": precisions,
+ "recalls": recalls,
+ "mAP": precisions.mean(),
+ "mAR": recalls.mean(),
+ }
diff --git a/deeplabcut/core/metrics/__init__.py b/deeplabcut/core/metrics/__init__.py
new file mode 100644
index 0000000000..94397de57a
--- /dev/null
+++ b/deeplabcut/core/metrics/__init__.py
@@ -0,0 +1,13 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from .api import compute_metrics, prepare_evaluation_data
+from .bbox import compute_bbox_metrics
+from .identity import compute_identity_scores
diff --git a/deeplabcut/core/metrics/api.py b/deeplabcut/core/metrics/api.py
new file mode 100644
index 0000000000..75d4e7bbcb
--- /dev/null
+++ b/deeplabcut/core/metrics/api.py
@@ -0,0 +1,176 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""API methods to get metrics for deep learning models"""
+from __future__ import annotations
+
+import numpy as np
+
+import deeplabcut.core.metrics.distance_metrics as distance_metrics
+
+
+def compute_metrics(
+ ground_truth: dict[str, np.ndarray],
+ predictions: dict[str, np.ndarray],
+ single_animal: bool = False,
+ unique_bodypart_gt: dict[str, np.ndarray] | None = None,
+ unique_bodypart_poses: dict[str, np.ndarray] | None = None,
+ pcutoff: float = -1,
+ oks_bbox_margin: int = 0,
+ oks_sigma: float = 0.1,
+ per_keypoint_rmse: bool = False,
+ compute_detection_rmse: bool = True,
+) -> dict:
+ """Computes pose estimation performance metrics
+
+ Given ground truth pose labels and predictions on a dataset, computes RMSE and pose
+ mAP/mAR using OKS.
+
+ The image paths in the ground_truth dict must be the same as the ones in the
+ predictions dict.
+
+ Single animal RMSE is computed by simply calculating the Euclidean distance between
+ each ground truth keypoint and the corresponding prediction.
+
+ Multi-animal RMSE is computed differently: predictions are first matched to ground
+ truth individuals using greedy OKS matching. OKS (or object keypoint similarity) is
+ a similarity metric for keypoints (you can read more about it and its definition
+ here: https://cocodataset.org/#keypoints-eval). RMSE is then computed only between
+ predictions and the ground truth pose they are matched to, only when the OKS is
+ greater than a small threshold. Predictions that cannot be matched to any ground
+ truth with non-zero OKS are not used to compute RMSE.
+
+ Args:
+ ground_truth: The ground truth pose for which to compute metrics in the dataset.
+ This should be a dictionary mapping strings (image UIDs, such as image
+ paths) to ground truth pose for the image. The pose arrays should be
+ in the format (num_individuals, num_bodyparts, 3), where the 3 values are
+ x, y and visibility. The ``num_individuals`` corresponds to the number of
+ individuals labeled in each image.
+ predictions: The predicted poses for which to compute metrics in the dataset.
+ This should be a dictionary mapping strings (image UIDs, such as image
+ paths) to pose predictions for the image. The pose arrays should be
+ in the format (num_predictions, num_bodyparts, 3), where the 3 values are
+ x, y and score. The number of predictions can be different to the number of
+ ground truth individuals labeled for an image.
+ single_animal: Whether the metrics are being computed on a single-animal or
+ multi-animal dataset. This has an impact on RMSE computation.
+ unique_bodypart_gt: If unique bodyparts are defined for the dataset, they should
+ be contained in this dict in the same format as the ``ground_truth`` dict.
+ unique_bodypart_poses: If unique bodyparts are defined for the dataset, the
+ predictions should be contained in this dict in the same format as the
+ ``predictions`` dict.
+ pcutoff: The threshold to compute the "rmse_cutoff" score (RMSE of all
+ predictions with score above the cutoff).
+ oks_bbox_margin: The margin to add around keypoints to compute the area for OKS
+ computation.
+ oks_sigma: The OKS sigma to use to compute pose.
+ per_keypoint_rmse: Compute per-keypoint RMSE values.
+ compute_detection_rmse: Computes detection RMSE (without animal assembly) if the
+ predictions are from a multi-animal model.
+
+ Returns:
+ A dictionary containing keys "rmse", "rmse_cutoff", "mAP" and "mAR" mapping
+ to those metrics on the given dataset.
+
+ If unique bodyparts are given, two extra keys "rmse_unique_bodyparts" and
+ "rmse_pcutoff_unique_bodyparts" are also returned, containing the metrics for
+ the unique bodyparts head.
+
+ If `per_keypoint_evaluation=True`, "keypoint_rmse", "keypoint_rmse_cutoff" (and
+ optionally "unique_keypoint_rmse" and "unique_keypoint_rmse_cutoff") keys are
+ added, containing a list of floats representing the RMSE for each keypoint.
+
+ Examples:
+ >>> # Define the p-cutoff, prediction, and target DataFrames
+ >>> pcutoff = 0.5
+ >>> ground_truth = {"img0": np.array([[[1.0, 1.0, 2.0], ...], ...]), ...}
+ >>> predictions = {"img0": np.array([[[2.0, 1.0, 0.4], ...], ...]), ...}
+ >>> scores = compute_metrics(ground_truth, predictions, pcutoff=pcutoff)
+ >>> print(scores)
+ {
+ "rmse": 1.0,
+ "rmse_pcutoff": 0.0,
+ 'mAP': 84.2,
+ 'mAR': 74.5
+ } # Sample output scores
+ """
+ data = prepare_evaluation_data(ground_truth, predictions)
+ oks_scores = distance_metrics.compute_oks(
+ data=data,
+ oks_sigma=oks_sigma,
+ oks_bbox_margin=oks_bbox_margin,
+ )
+
+ data_unique = None
+ if unique_bodypart_gt is not None:
+ assert unique_bodypart_poses is not None
+ data_unique = prepare_evaluation_data(unique_bodypart_gt, unique_bodypart_poses)
+
+ rmse_scores = distance_metrics.compute_rmse(
+ data,
+ single_animal,
+ pcutoff,
+ data_unique=data_unique,
+ per_keypoint_results=per_keypoint_rmse,
+ )
+ results = dict(**rmse_scores, **oks_scores)
+
+ if compute_detection_rmse and not single_animal:
+ det_rmse, det_rmse_p = distance_metrics.compute_detection_rmse(
+ data, pcutoff, data_unique=data_unique,
+ )
+ results["rmse_detections"] = det_rmse
+ results["rmse_detections_pcutoff"] = det_rmse_p
+
+ return results
+
+
+def prepare_evaluation_data(
+ ground_truth: dict[str, np.ndarray],
+ predictions: dict[str, np.ndarray],
+) -> list[tuple[np.ndarray, np.ndarray]]:
+ """Prepares predictions and ground truth pose to compute metrics.
+
+ Only keeps ground truth and predicted assemblies with at least 2 valid keypoints.
+ Sets the coordinates for all keypoints that aren't visible (for ground truth,
+ visibility <= 0 and for predictions score <= 0) to ``np.nan``.
+
+ Sorts valid predictions by score.
+
+ Args:
+ ground_truth: For each image, the GT of shape (n_idv, n_bpt, 3).
+ predictions: For each image, the pose predictions of shape (n_pred, n_bpt, 3).
+
+ Returns:
+ A list containing (ground truth pose, predicted pose) for each image in the
+ dataset, where the predicted pose is sorted from highest to lowest score.
+ """
+ pose_data = []
+ for image, gt in ground_truth.items():
+ gt = gt.copy()
+ gt[gt[..., 2] <= 0] = np.nan
+
+ # only keep ground truth pose with at least one keypoint
+ gt_mask = np.any(np.all(~np.isnan(gt), axis=-1), axis=-1)
+ gt = gt[gt_mask]
+
+ pred = predictions[image][..., :3].copy() # PAF have 5 values; keep xy + score
+ pred[pred[..., 2] < 0] = np.nan
+
+ # only keep predicted pose with at least two keypoints
+ pred_mask = np.any(np.all(~np.isnan(pred), axis=-1), axis=-1)
+ pred = pred[pred_mask]
+
+ scores = np.nanmean(pred[:, :, 2], axis=-1)
+ pred_order = np.argsort(-scores, kind="mergesort")
+ pose_data.append((gt, pred[pred_order]))
+
+ return pose_data
diff --git a/deeplabcut/core/metrics/bbox.py b/deeplabcut/core/metrics/bbox.py
new file mode 100644
index 0000000000..2ee60cdc86
--- /dev/null
+++ b/deeplabcut/core/metrics/bbox.py
@@ -0,0 +1,159 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Bounding box metrics
+
+Metrics are currently computed using pycocotools, which can be installed with `pypi`
+(see https://github.com/ppwwyyxx/cocoapi/tree/master).
+"""
+from __future__ import annotations
+
+from unittest.mock import Mock, patch
+
+import numpy as np
+
+try:
+ from pycocotools.coco import COCO
+ from pycocotools.cocoeval import COCOeval
+
+ with_pycocotools = True
+except ModuleNotFoundError as err:
+ with_pycocotools = False
+
+
+@patch("pycocotools.coco.print", Mock())
+@patch("pycocotools.cocoeval.print", Mock())
+def compute_bbox_metrics(
+ ground_truth: dict[str, dict],
+ detections: dict[str, dict],
+) -> dict[str, float]:
+ """Computes bbox mAP and mAR metrics for bounding boxes.
+
+ Args:
+ ground_truth: A dictionary mapping image UIDs (such as image paths or filenames)
+ to a ground truth labels dict. The labels dict should contain the keys
+ "width" (image width), "height" (image height) and "bboxes" (a numpy array
+ of shape (num_gt_bboxes, 4) containing the ground truth bounding boxes in
+ format xywh).
+ detections: A dictionary mapping image UIDs (such as image paths or filenames)
+ to a predicted bounding box dict. The detections dict should contain the
+ keys "bboxes" (a numpy array of shape (num_detected_bboxes, 4) containing
+ the predicted bounding boxes in format xywh) and "scores" (a numpy array of
+ length num_detected_bboxes containing the confidence score for each
+ predicted bounding box).
+
+ Returns:
+ The bounding box mAP/mAR metrics in a dictionary.
+
+ Raises:
+ ModuleNotFoundError: if ``pycocotools`` is not installed
+ ValueError: if there are mismatches in the keys of ground_truth and detections
+ """
+ if not with_pycocotools:
+ raise ModuleNotFoundError("pycocotools not installed! can't compute bbox mAP")
+
+ if len(detections) != len(ground_truth):
+ raise ValueError()
+
+ coco = COCO()
+ coco.dataset["annotations"] = []
+ coco.dataset["categories"] = [{"id": 1, "name": "animals", "supercategory": "obj"}]
+ coco.dataset["images"] = []
+ predictions = []
+ for idx, (img, gt) in enumerate(ground_truth.items()):
+ img_id = idx + 1
+ coco.dataset["images"].append(
+ {
+ "id": img_id,
+ "file_name": img,
+ "width": gt["width"],
+ "height": gt["height"],
+ }
+ )
+ for bbox in gt["bboxes"][:, :4]:
+ ann_id = len(coco.dataset["annotations"]) + 1
+ coco.dataset["annotations"].append(
+ {
+ "id": ann_id,
+ "image_id": img_id,
+ "category_id": 1,
+ "area": max(1, (bbox[2] * bbox[3]).item()),
+ "bbox": bbox,
+ "iscrowd": 0,
+ }
+ )
+
+ for bbox, score in zip(detections[img]["bboxes"], detections[img]["scores"]):
+ predictions.append(np.array([img_id, *bbox, score, 1]))
+
+ if len(predictions) == 0:
+ return {
+ "mAP@50:95": 0.0,
+ "mAP@50": 0.0,
+ "mAP@75": 0.0,
+ "mAR@50:95": 0.0,
+ "mAR@50": 0.0,
+ "mAR@75": 0.0,
+ }
+
+ predictions = np.stack(predictions, axis=0)
+ coco.createIndex()
+ coco_det = coco.loadRes(predictions)
+ coco_eval = COCOeval(coco, coco_det, iouType="bbox")
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ return {
+ name: val
+ for name, val in [
+ _get_metric(coco_eval, recall=False),
+ _get_metric(coco_eval, recall=False, iou_threshold=0.5),
+ _get_metric(coco_eval, recall=False, iou_threshold=0.75),
+ _get_metric(coco_eval, recall=True),
+ _get_metric(coco_eval, recall=True, iou_threshold=0.5),
+ _get_metric(coco_eval, recall=True, iou_threshold=0.75),
+ ]
+ }
+
+
+def _get_metric(
+ coco_eval: COCOeval,
+ recall: bool = False,
+ iou_threshold: float | None = None,
+ area_rng: str = "all",
+ max_dets: int = 100,
+) -> tuple[str, float]:
+ metric_name = "mAR" if recall else "mAP"
+ if iou_threshold is not None:
+ thresh = f"{int(100 * iou_threshold)}"
+ else:
+ low, high = coco_eval.params.iouThrs[0], coco_eval.params.iouThrs[-1]
+ thresh = f"{int(100 * low)}:{int(100 * high)}"
+
+ aind = [i for i, aRng in enumerate(coco_eval.params.areaRngLbl) if aRng == area_rng]
+ mind = [i for i, mDet in enumerate(coco_eval.params.maxDets) if mDet == max_dets]
+ if recall:
+ s = coco_eval.eval["recall"]
+ if iou_threshold is not None:
+ t = np.where(iou_threshold == coco_eval.params.iouThrs)[0]
+ s = s[t]
+ s = s[:, :, aind, mind]
+ else:
+ s = coco_eval.eval["precision"]
+ if iou_threshold is not None:
+ t = np.where(iou_threshold == coco_eval.params.iouThrs)[0]
+ s = s[t]
+ s = s[:, :, :, aind, mind]
+
+ if len(s[s > -1]) == 0:
+ mean_s = -1
+ else:
+ mean_s = 100 * np.mean(s[s > -1]).item()
+
+ return f"{metric_name}@{thresh}", mean_s
diff --git a/deeplabcut/core/metrics/distance_metrics.py b/deeplabcut/core/metrics/distance_metrics.py
new file mode 100644
index 0000000000..ac2fbc04c1
--- /dev/null
+++ b/deeplabcut/core/metrics/distance_metrics.py
@@ -0,0 +1,459 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Implementations of methods to compute distance metrics such as RMSE or OKS"""
+from __future__ import annotations
+
+import numpy as np
+
+import deeplabcut.core.metrics.matching as matching
+from deeplabcut.core.crossvalutils import find_closest_neighbors
+from deeplabcut.core.inferenceutils import calc_object_keypoint_similarity
+
+
+def compute_oks_matrix(
+ ground_truth: np.ndarray,
+ predictions: np.ndarray,
+ oks_sigma: float,
+ oks_bbox_margin: float = 0.0,
+) -> np.ndarray:
+ """Computes the OKS score for each (prediction, gt) pair in an image
+
+ Args:
+ ground_truth: The GT poses for an image, shape (n_individuals, n_kpts, 2)
+ predictions: The predicted poses in the image, shape (n_pred, n_kpts, 2)
+ oks_sigma: The sigma value to use to compute OKS
+ oks_bbox_margin: The margin to add around keypoints when computing the area.
+ FIXME(niels) We should allow the use of ground truth bboxes to get area
+
+ Returns:
+ A matrix of shape (n_pred, n_kpts) where entry (i, j) is the OKS between
+ prediction i and ground truth j.
+ """
+ oks_matrix = np.zeros((len(predictions), len(ground_truth)))
+ for pred_idx, pred in enumerate(predictions):
+ for gt_idx, gt in enumerate(ground_truth):
+ oks_matrix[pred_idx, gt_idx] = calc_object_keypoint_similarity(
+ pred[:, :2],
+ gt[:, :2],
+ sigma=oks_sigma,
+ margin=oks_bbox_margin,
+ )
+
+ return oks_matrix
+
+
+def compute_oks(
+ data: list[tuple[np.ndarray, np.ndarray]],
+ oks_bbox_margin: float = 0.0,
+ oks_sigma: float = 0.1,
+ oks_thresholds: np.ndarray | None = None,
+ oks_recall_thresholds: np.ndarray | None = None,
+) -> dict[str, float]:
+ """Computes the OKS for pose at different thresholds.
+
+ Args:
+ data: The data for which to compute OKS mAP: a list containing (gt_poses,
+ predicted_poses) tuples, where gt_pose is an array of shape
+ (num_gt_individuals, num_bpts, 3) and predicted_poses is an array of shape
+ (num_predictions, num_bpts, 3). For the GT, the 3 coordinates are (x, y,
+ visibility) while for the pose they are (x, y, confidence score).
+ oks_sigma: The OKS sigma to use to compute pose.
+ oks_bbox_margin: The margin to add around keypoints to compute the area for OKS
+ computation.
+ oks_thresholds: The OKS thresholds at which to compute AP. If None, defaults to
+ (0.5, 0.55, 0.6, ..., 0.9, 0.95).
+ oks_recall_thresholds: The recall thresholds to use to compute mAP. If None,
+ defaults to the same default values used in pycocotools.
+
+ Returns:
+ A dictionary containing mAP and mAR scores.
+ """
+ if oks_thresholds is None:
+ oks_thresholds = np.linspace(0.5, 0.95, 10)
+
+ if oks_recall_thresholds is None:
+ oks_recall_thresholds = np.linspace(
+ start=0.0,
+ stop=1.00,
+ num=int(np.round((1.00 - 0.0) / 0.01)) + 1,
+ endpoint=True,
+ )
+
+ total_gt = 0
+ pose_data = []
+ for gt, pred in data:
+ # filter data to only keep individuals with at least 2 valid keypoints
+ gt = gt[np.sum(np.all(~np.isnan(gt), axis=-1), axis=-1) > 1]
+ pred = pred[np.sum(np.all(~np.isnan(pred), axis=-1), axis=-1) > 1]
+
+ oks_matrix = compute_oks_matrix(
+ gt[:, :, :2],
+ pred[:, :, :2],
+ oks_sigma=oks_sigma,
+ oks_bbox_margin=oks_bbox_margin,
+ )
+
+ total_gt += len(gt)
+ pose_data.append((gt, pred, oks_matrix))
+
+ precisions, recalls = [], []
+ for oks_threshold in oks_thresholds:
+ matches = []
+ for gt, pred, oks_matrix in pose_data:
+ image_matches = matching.match_greedy_oks(
+ gt,
+ pred,
+ oks_matrix=oks_matrix,
+ oks_threshold=oks_threshold,
+ )
+ matches.extend(image_matches)
+
+ if len(matches) == 0: # no predictions -> precision 0, recall 0
+ return {"mAP": 0, "mAR": 0}
+
+ scores = np.asarray([m.score for m in matches])
+ match_order = np.argsort(-scores, kind="mergesort")
+ oks_values = np.asarray([m.oks for m in matches])
+ oks_values = oks_values[match_order]
+
+ tp = np.cumsum(oks_values >= oks_threshold)
+ fp = np.cumsum(oks_values < oks_threshold)
+ rc = tp / total_gt
+ pr = tp / (fp + tp + np.spacing(1))
+ recall = rc[-1]
+
+ # Guarantee precision decreases monotonically, see
+ # https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173
+ for i in range(len(pr) - 1, 0, -1):
+ if pr[i] > pr[i - 1]:
+ pr[i - 1] = pr[i]
+
+ inds_rc = np.searchsorted(rc, oks_recall_thresholds, side="left")
+ precision = np.zeros(inds_rc.shape)
+ valid = inds_rc < len(pr)
+ precision[valid] = pr[inds_rc[valid]]
+
+ precisions.append(precision)
+ recalls.append(recall)
+
+ precisions = np.asarray(precisions)
+ recalls = np.asarray(recalls)
+ return {
+ "mAP": 100 * precisions.mean().item(),
+ "mAR": 100 * recalls.mean().item(),
+ }
+
+
+def match_predictions_for_rmse(
+ data: list[tuple[np.ndarray, np.ndarray]],
+ single_animal: bool,
+ oks_bbox_margin: float = 0.0,
+) -> list[matching.PotentialMatch]:
+ """Matches GT keypoints to predictions to compute RMSE.
+
+ Single animal RMSE is computed by simply calculating the distance between each
+ ground truth keypoint and the corresponding prediction.
+
+ Multi-animal RMSE is computed differently: predictions are first matched to ground
+ truth individuals using greedy OKS matching. RMSE is then computed only between
+ predictions and the ground truth pose they are matched to, only when the OKS is
+ non-zero (greater than a small threshold). Predictions that cannot be matched to
+ any ground truth with non-zero OKS are not used to compute RMSE.
+
+ Args:
+ data: The data for which to compute RMSE. This is a list containing (gt_poses,
+ predicted_poses), where gt_pose is an array of shape (num_gt_individuals,
+ num_bpts, 3) and predicted_poses is an array of shape (num_predictions,
+ num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for
+ the pose they are (x, y, confidence score).
+ single_animal: Whether this is a single animal dataset.
+ oks_bbox_margin: When single_animal is False, predictions are matched to GT
+ using OKS. This is the margin used to apply when computing the bbox from
+ the pose to compute OKS.
+
+ Returns:
+ A list containing the predictions matched to ground truth.
+
+ Raises:
+ ValueError: If `single_animal=True` but more than one ground truth/predicted
+ keypoint is found for an entry
+ """
+ matches = []
+ for gt, pred in data:
+ if single_animal:
+ if gt.shape[0] > 1 or pred.shape[0] > 1:
+ raise ValueError(
+ "At most 1 individual and 1 prediction can be given when computing "
+ f"single animal RMSE. Found gt={gt.shape}, pred={pred.shape}"
+ )
+
+ image_matches = []
+ if gt.shape[0] == 1 and pred.shape[0] == 1:
+ match = matching.PotentialMatch.from_pose(pred[0])
+ match.match(gt[0], oks=float("nan")) # OKS not needed for RMSE
+ image_matches.append(match)
+ else:
+ oks_matrix = compute_oks_matrix(
+ gt[:, :, :2],
+ pred[:, :, :2],
+ oks_sigma=0.1,
+ oks_bbox_margin=oks_bbox_margin,
+ )
+ image_matches = matching.match_greedy_oks(
+ gt,
+ pred,
+ oks_matrix=oks_matrix,
+ oks_threshold=1e-6,
+ )
+
+ matches.extend(image_matches)
+
+ return matches
+
+
+def compute_rmse(
+ data: list[tuple[np.ndarray, np.ndarray]],
+ single_animal: bool,
+ pcutoff: float | list[float],
+ data_unique: list[tuple[np.ndarray, np.ndarray]] | None = None,
+ per_keypoint_results: bool = False,
+ oks_bbox_margin: float = 0.0,
+) -> dict[str, float]:
+ """Computes the RMSE for pose predictions.
+
+ Single animal RMSE is computed by simply calculating the distance between each
+ ground truth keypoint and the corresponding prediction.
+
+ Multi-animal RMSE is computed differently: predictions are first matched to ground
+ truth individuals using greedy OKS matching. RMSE is then computed only between
+ predictions and the ground truth pose they are matched to, only when the OKS is
+ non-zero (greater than a small threshold). Predictions that cannot be matched to
+ any ground truth with non-zero OKS are not used to compute RMSE.
+
+ Args:
+ data: The data for which to compute RMSE. This is a list containing (gt_poses,
+ predicted_poses), where gt_pose is an array of shape (num_gt_individuals,
+ num_bpts, 3) and predicted_poses is an array of shape (num_predictions,
+ num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for
+ the pose they are (x, y, confidence score).
+ single_animal: Whether this is a single animal dataset.
+ pcutoff: The p-cutoff to use to compute RMSE. If a list, the cutoff for each
+ bodypart is set individually. The list must have length num_bodyparts +
+ num_unique_bodyparts.
+ data_unique: Unique bodypart ground truth and predictions to include in RMSE
+ computations, if there are any such bodyparts.
+ per_keypoint_results: Whether to compute the RMSE for each individual keypoint.
+ oks_bbox_margin: When single_animal is False, predictions are matched to GT
+ using OKS. This is the margin used to apply when computing the bbox from
+ the pose to compute OKS.
+
+ Returns:
+ A dictionary matching metric names to values. It will at least have "rmse" and
+ "rmse_cutoff" keys. If `per_keypoint_results=True` and there is at least one
+ non-NaN pixel error it will also contain "rmse_keypoint_X" and
+ "rmse_cutoff_keypoint_X" keys for each bodypart, where X is the index of the
+ bodypart.
+
+ Raises:
+ ValueError: If `single_animal=True` but more than one ground truth/predicted
+ keypoint is found for an entry
+ """
+ matches = match_predictions_for_rmse(data, single_animal, oks_bbox_margin)
+ pixel_errors, keypoint_scores = None, None
+ if len(matches) > 0:
+ pixel_errors = np.stack([m.pixel_errors() for m in matches])
+ keypoint_scores = np.stack([m.keypoint_scores() for m in matches])
+
+ error, support, cutoff_error, cutoff_support = 0, 0, 0, 0
+ if pixel_errors is not None:
+ bpt_cutoffs = pcutoff
+ if not isinstance(pcutoff, (int, float)):
+ bpt_cutoffs = pcutoff[:pixel_errors.shape[1]]
+
+ error, support, cutoff_error, cutoff_support = collect_pixel_errors(
+ pixel_errors, keypoint_scores, bpt_cutoffs,
+ )
+
+ unique_pixel_errors, unique_keypoint_scores = None, None
+ if data_unique is not None:
+ u_matches = match_predictions_for_rmse(data_unique, single_animal=True)
+ if len(u_matches) > 0:
+ unique_pixel_errors = np.stack([m.pixel_errors() for m in u_matches])
+ unique_keypoint_scores = np.stack([m.keypoint_scores() for m in u_matches])
+
+ bpt_cutoffs = pcutoff
+ if not isinstance(pcutoff, (int, float)):
+ bpt_cutoffs = pcutoff[-unique_pixel_errors.shape[1]:]
+ u_error, u_support, u_cutoff_error, u_cutoff_support = collect_pixel_errors(
+ unique_pixel_errors, unique_keypoint_scores, bpt_cutoffs,
+ )
+ error += u_error
+ support += u_support
+ cutoff_error += u_cutoff_error
+ cutoff_support += u_cutoff_support
+
+ results = dict(rmse=float("nan"), rmse_pcutoff=float("nan"))
+ if support > 0:
+ results["rmse"] = float(error / support)
+ if cutoff_support > 0:
+ results["rmse_pcutoff"] = float(cutoff_error / cutoff_support)
+
+ if per_keypoint_results:
+ bodypart_errors = [("rmse_keypoint", pixel_errors)]
+ if unique_pixel_errors is not None:
+ bodypart_errors.append(("rmse_unique_keypoint", unique_pixel_errors))
+
+ for key_prefix, bpt_errors in bodypart_errors:
+ for idx, keypoint_error in enumerate(bpt_errors.T):
+ rmse = float("nan")
+ if np.any(~np.isnan(keypoint_error)):
+ rmse = np.nanmean(keypoint_error).item()
+ results[f"{key_prefix}_{idx}"] = float(rmse)
+
+ return results
+
+
+def compute_detection_rmse(
+ data: list[tuple[np.ndarray, np.ndarray]],
+ pcutoff: float | list[float],
+ data_unique: list[tuple[np.ndarray, np.ndarray]] | None = None,
+) -> tuple[float, float]:
+ """Computes the detection RMSE for pose predictions.
+
+ The detection RMSE score does not take individual assemblies into account. It only
+ judges the performance of the detections, matching each predicted keypoint to the
+ closest ground truth for each bodypart.
+
+ This is the same way multi-animal RMSE was computed in DeepLabCut 2.X.
+
+ Args:
+ data: The data for which to compute RMSE. This is a list containing (gt_poses,
+ predicted_poses), where gt_pose is an array of shape (num_gt_individuals,
+ num_bpts, 3) and predicted_poses is an array of shape (num_predictions,
+ num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for
+ the pose they are (x, y, confidence score).
+ pcutoff: The p-cutoff to use to compute RMSE. If a list, the cutoff for each
+ bodypart is set individually. The list must have length num_bodyparts +
+ num_unique_bodyparts.
+ data_unique: Unique bodypart ground truth and predictions to include in RMSE
+ computations, if there are any such bodyparts.
+
+ Returns:
+ The detection RMSE and detection RMSE after removing all detections with a
+ score below the pcutoff.
+ """
+ distances = []
+ distances_cutoff = []
+ for image_gt, image_pred in data:
+ image_gt = image_gt.transpose((1, 0, 2)) # to (num_bpts, num_gt_individuals, 3)
+ image_pred = image_pred.transpose((1, 0, 2)) # to (num_bpts, num_pred, 3)
+
+ for bpt_index, (bpt_gt, bpt_pred) in enumerate(zip(image_gt, image_pred)):
+ # filter NaNs and invalid values
+ bpt_gt = bpt_gt[~np.any(np.isnan(bpt_gt), axis=1)]
+ bpt_pred = bpt_pred[~np.any(np.isnan(bpt_pred), axis=1)]
+ if len(bpt_gt) == 0 or len(bpt_pred) == 0:
+ continue
+
+ if isinstance(pcutoff, (int, float)):
+ bpt_pcutoff = pcutoff
+ else:
+ bpt_pcutoff = pcutoff[bpt_index]
+
+ # assignment of predicted bodyparts to ground truth
+ neighbors = find_closest_neighbors(bpt_gt, bpt_pred, k=3)
+ for gt_index, pred_index in enumerate(neighbors):
+ if pred_index != -1:
+ gt = bpt_gt[gt_index]
+ pred = bpt_pred[pred_index]
+ dist = np.linalg.norm(gt[:2] - pred[:2])
+ distances.append(dist)
+
+ score = bpt_pred[pred_index, 2]
+ if score >= bpt_pcutoff:
+ distances_cutoff.append(dist)
+
+ if data_unique is not None:
+ for image_gt, image_pred in data_unique:
+ assert len(image_gt) <= 1 and len(image_pred) <= 1, (
+ f"Unique GT an predictions must have length 0 or 1! Found {image_gt.shape}, "
+ f"{image_pred.shape}."
+ )
+
+ if len(image_gt) == 1 and len(image_pred) == 1:
+ unique_gt, unique_pred = image_gt[0], image_pred[0]
+ num_unique = unique_gt.shape[0]
+ unique_cutoffs = pcutoff
+ if not isinstance(pcutoff, (int, float)):
+ unique_cutoffs = pcutoff[-num_unique:]
+
+ for bpt_index, (gt, pred) in enumerate(zip(unique_gt, unique_pred)):
+ dist = np.linalg.norm(gt[:2] - pred[:2])
+ distances.append(dist)
+
+ score = pred[2]
+ if isinstance(pcutoff, (int, float)):
+ bpt_pcutoff = unique_cutoffs
+ else:
+ bpt_pcutoff = unique_cutoffs[bpt_index]
+
+ if score >= bpt_pcutoff:
+ distances_cutoff.append(dist)
+
+ rmse, rmse_cutoff = float("nan"), float("nan")
+ if len(distances) == 0:
+ return rmse, rmse_cutoff
+
+ distances = np.stack(distances)
+ if np.any(~np.isnan(distances)):
+ rmse = float(np.nanmean(distances).item())
+
+ if len(distances_cutoff) > 0:
+ distances_cutoff = np.stack(distances_cutoff)
+ if np.any(~np.isnan(distances_cutoff)):
+ rmse_cutoff = float(np.nanmean(distances_cutoff).item())
+
+ return rmse, rmse_cutoff
+
+
+def collect_pixel_errors(
+ pixel_errors: np.ndarray,
+ keypoint_scores: np.ndarray,
+ pcutoff: float,
+) -> tuple[float, int, float, int]:
+ """Collects pixel errors for RMSE computation
+
+ Args:
+ pixel_errors: The pixel errors to collect, of shape (num_matches, num_bodyparts)
+ keypoint_scores: The scores corresponding to the pixel errors, of shape
+ (num_matches, num_bodyparts).
+ pcutoff: The pcutoff to use when computing cutoff RMSE.
+
+ Returns: error, support, cutoff_error, support_cutoff
+ error: The sum of all pixel errors.
+ support: The number of valid pixel errors.
+ cutoff_error: The sum of all pixel errors with score > pcutoff.
+ support_cutoff: The number of valid pixel errors with score > pcutoff.
+ """
+ error = 0.0
+ cutoff_error = 0.0
+ support = np.sum(~np.isnan(pixel_errors)).item()
+ support_cutoff = 0
+ if support > 0:
+ error += np.nansum(pixel_errors).item()
+
+ cutoff_mask = keypoint_scores >= pcutoff
+ cutoff_pixel_errors = pixel_errors[cutoff_mask]
+ support_cutoff = np.sum(~np.isnan(cutoff_pixel_errors)).item()
+ if support_cutoff > 0:
+ cutoff_error = np.nansum(cutoff_pixel_errors).item()
+
+ return error, support, cutoff_error, support_cutoff
diff --git a/deeplabcut/core/metrics/identity.py b/deeplabcut/core/metrics/identity.py
new file mode 100644
index 0000000000..1720bfdffa
--- /dev/null
+++ b/deeplabcut/core/metrics/identity.py
@@ -0,0 +1,92 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Implementations of methods to compute identity prediction accuracy"""
+from __future__ import annotations
+
+import numpy as np
+from sklearn.metrics import accuracy_score
+
+from deeplabcut.core.crossvalutils import find_closest_neighbors
+
+
+def compute_identity_scores(
+ individuals: list[str],
+ bodyparts: list[str],
+ predictions: dict[str, np.ndarray],
+ identity_scores: dict[str, np.ndarray],
+ ground_truth: dict[str, np.ndarray],
+) -> dict[str, float]:
+ """
+ FIXME: With DLCRNet all heatmap "peaks" above 0.01 were kept, with 1 keypoint and
+ 1 identity score map per peak. Then, for each ground truth keypoint, we selected
+ the prediction closest to it, and evaluated the identity score in that position.
+ This is no longer the case, as we're now evaluating after assembly. So we only
+ have num_individuals assemblies.
+
+ Args:
+ individuals:
+ bodyparts:
+ predictions: (num_assemblies, num_bodyparts, 3)
+ identity_scores: (num_assemblies, num_bodyparts, num_individuals)
+ ground_truth: (num_individuals, num_bodyparts, 3)
+
+ Returns:
+
+ """
+ if not len(predictions) == len(ground_truth):
+ raise ValueError("Mismatch between number of predictions and ground truth")
+
+ all_bpts = np.asarray(len(individuals) * bodyparts)
+ ids = np.full((len(predictions), len(all_bpts), 2), np.nan)
+ for i, (image, pred) in enumerate(predictions.items()):
+ for j in range(len(individuals)):
+ for k in range(len(bodyparts)):
+ bpt_idx = len(bodyparts) * j + k
+ ids[i, bpt_idx, 0] = j
+
+ # set keypoints that aren't visible to NaN
+ gt = ground_truth[image].copy()
+ gt[gt[..., 2] <= 0, :2] = np.nan
+ gt = gt[..., :2]
+
+ id_scores = identity_scores[image]
+
+ # reorder to (bodypart, individual, ...)
+ gt = gt.transpose((1, 0, 2))
+ pred = pred.transpose((1, 0, 2))[..., :2]
+ id_scores = id_scores.transpose((1, 0, 2))
+ for bpt, bpt_gt, bpt_pred, bpt_id_scores in zip(bodyparts, gt, pred, id_scores):
+ # assign ground truth keypoints to the closest prediction, so the ID score
+ # is the closest possible to the ID score computed with "ground truth"
+ indices_gt = np.flatnonzero(np.all(~np.isnan(bpt_gt), axis=1))
+
+ # Remove NaN predictions from the bodypart predictions
+ indices_pred = np.all(np.isfinite(bpt_pred), axis=1)
+ bpt_pred = bpt_pred[indices_pred]
+ bpt_id_scores = bpt_id_scores[indices_pred]
+
+ neighbors = find_closest_neighbors(bpt_gt[indices_gt], bpt_pred, k=3)
+ found = neighbors != -1
+ indices = np.flatnonzero(all_bpts == bpt)
+ # Get the predicted identity of each bodypart by taking the argmax
+ ids[i, indices[indices_gt[found]], 1] = np.argmax(
+ bpt_id_scores[neighbors[found]], axis=1
+ )
+
+ ids = ids.reshape((len(predictions), len(individuals), len(bodyparts), 2))
+ results = {}
+ for i, bpt in enumerate(bodyparts):
+ temp = ids[:, :, i].reshape((-1, 2))
+ valid = np.isfinite(temp).all(axis=1)
+ y_true, y_pred = temp[valid].T
+ results[f"{bpt}_accuracy"] = accuracy_score(y_true, y_pred)
+
+ return results
diff --git a/deeplabcut/core/metrics/matching.py b/deeplabcut/core/metrics/matching.py
new file mode 100644
index 0000000000..95b28ebe5b
--- /dev/null
+++ b/deeplabcut/core/metrics/matching.py
@@ -0,0 +1,169 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Algorithms to match predictions to ground truth labels"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import numpy as np
+
+
+@dataclass
+class PotentialMatch:
+ """A potential match between predicted pose and ground truth pose.
+
+ Args:
+ pose: An array of shape (num_bodyparts, 3)
+ score: The score for the prediction. This could be the mean of the confidence
+ score for each bodypart, or another value representing how confident the
+ model is that this assembly is correct.
+ gt: None if no ground truth pose was matched to the prediction. If defined, the
+ ground truth to which the prediction is matched. It should be of shape
+ (num_bodyparts, 3), where the 3 values are x, y and visibility.
+ oks: The OKS score between the pose and the ground truth.
+ """
+
+ pose: np.ndarray
+ score: float
+ gt: np.ndarray | None = None
+ oks: float = 0.0
+
+ def keypoint_scores(self) -> np.ndarray:
+ """Returns: The confidence score for each bodypart in the predicted pose."""
+ return self.pose[:, 2].copy()
+
+ def pixel_errors(self) -> np.ndarray:
+ """
+ Returns:
+ The distance (in pixels) between each predicted and ground truth bodypart.
+ If this prediction is unmatched, returns an array of length num_bodyparts
+ containing all NaNs.
+ """
+ if self.gt is None:
+ return np.full(len(self.pose), np.nan)
+
+ return np.linalg.norm(self.pose[:, :2] - self.gt[:, :2], axis=1)
+
+ def match(self, gt: np.ndarray, oks: float) -> None:
+ """Adds a ground truth match to this PotentialMatch
+
+ Args:
+ gt: The ground truth to which the prediction is matched. The ground truth
+ pose should be of shape (num_bodyparts, 3), where the 3 values are x, y
+ and visibility.
+ oks: The OKS similarity between the ground truth and this.
+ """
+ self.gt = gt
+ self.oks = oks
+
+ @classmethod
+ def from_pose(cls, pose: np.ndarray) -> "PotentialMatch":
+ assert len(pose.shape) == 2 # Must be pose for a single individual
+ scores = pose[:, 2]
+ if np.all(np.isnan(scores)):
+ raise ValueError(
+ "Cannot create a Match from a pose prediction where all scores are nan "
+ f"(pose={pose})"
+ )
+
+ return PotentialMatch(pose=pose, score=np.nanmean(scores).item())
+
+
+def match_greedy_oks(
+ ground_truth: np.ndarray,
+ predictions: np.ndarray,
+ oks_matrix: np.ndarray,
+ oks_threshold: float = 0.0,
+) -> list[PotentialMatch]:
+ """Greedy matching of ground truth individuals to predicted individuals using OKS
+
+ This is done in the same way as done in pycocotools. The predictions must be sorted
+ by score before being passed to this function.
+
+ Args:
+ ground_truth: The ground truth labels for an image, of shape (n_idv, n_bpt, 2)
+ predictions: The predictions for an image, of shape (n_idv, n_bpt, 2)
+ oks_matrix: A matrix of shape (n_pred, n_kpts) where entry (i, j) is the OKS
+ between prediction i and ground truth j.
+ oks_threshold: The min. OKS for a prediction to be matched to a GT pose
+
+ Returns:
+ A list containing a PotentialMatch for each predicted pose in the given
+ predictions.
+ """
+ matches = [PotentialMatch.from_pose(pose=pred) for pred in predictions]
+ matched_gt_indices = set()
+ for idx, pred in enumerate(predictions):
+ oks = oks_matrix[idx]
+ if np.all(np.isnan(oks)):
+ continue
+
+ ind_best = np.nanargmax(oks)
+
+ # if this gt already matched, continue
+ if ind_best in matched_gt_indices:
+ continue
+
+ # Only match the pred to the GT if the OKS value is above a given threshold
+ if oks[ind_best] < oks_threshold:
+ continue
+
+ matched_gt_indices.add(ind_best)
+ matches[idx].match(gt=ground_truth[ind_best], oks=oks[ind_best])
+
+ return matches
+
+
+def match_greedy_rmse(
+ ground_truth: np.ndarray,
+ predictions: np.ndarray,
+ keep_assemblies: bool = True,
+) -> list[PotentialMatch]:
+ """Greedy matching of ground truth individuals to predicted individuals using RMSE
+
+ The predictions must be sorted by score before being passed to this function.
+
+ Args:
+ ground_truth: The ground truth labels for an image, of shape (n_idv, n_bpt, 2)
+ predictions: The predictions for an image, of shape (n_idv, n_bpt, 2)
+ keep_assemblies: Whether to match predicted keypoints to ground truth keypoints
+ while enforcing that all bodyparts for a predicted individual are matched
+ to bodyparts from the same ground truth assembly. When set to False, this
+ corresponds to detection RMSE score.
+
+ Returns:
+ A list containing a PotentialMatch for each predicted pose in the given
+ predictions.
+ """
+ if not keep_assemblies:
+ raise NotImplementedError()
+
+ matches = [PotentialMatch.from_pose(pose=pred) for pred in predictions]
+ matched_gt_indices = set()
+ for idx, pred in enumerate(predictions):
+ bpt_distances = np.linalg.norm(pred[:, :2] - ground_truth[:, :, :2], axis=-1)
+ if np.all(np.isnan(bpt_distances)):
+ continue
+
+ distances = np.nanmean(bpt_distances, axis=-1)
+ ind_best = np.nanargmin(distances)
+
+ # if this gt already matched, continue
+ if ind_best in matched_gt_indices:
+ continue
+
+ matched_gt_indices.add(ind_best)
+ matches[idx].match(
+ gt=ground_truth[ind_best],
+ oks=float("nan"), # don't compute OKS here
+ )
+
+ return matches
diff --git a/deeplabcut/core/trackingutils.py b/deeplabcut/core/trackingutils.py
new file mode 100644
index 0000000000..565582dcd1
--- /dev/null
+++ b/deeplabcut/core/trackingutils.py
@@ -0,0 +1,840 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import abc
+import math
+import warnings
+from collections import defaultdict
+
+import numpy as np
+from filterpy.common import kinematic_kf
+from filterpy.kalman import KalmanFilter
+from matplotlib import patches
+from numba import jit
+from numba.core.errors import NumbaPerformanceWarning
+from scipy.optimize import linear_sum_assignment
+from scipy.stats import mode
+from tqdm import tqdm
+
+warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
+
+TRACK_METHODS = {
+ "box": "_bx",
+ "skeleton": "_sk",
+ "ellipse": "_el",
+ "transformer": "_tr",
+}
+
+
+def calc_iou(bbox1, bbox2):
+ x1 = max(bbox1[0], bbox2[0])
+ y1 = max(bbox1[1], bbox2[1])
+ x2 = min(bbox1[2], bbox2[2])
+ y2 = min(bbox1[3], bbox2[3])
+ w = max(0, x2 - x1)
+ h = max(0, y2 - y1)
+ wh = w * h
+ return wh / (
+ (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
+ + (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+ - wh
+ )
+
+
+class BaseTracker:
+ """Base class for a constant-velocity Kalman filter-based tracker."""
+
+ n_trackers = 0
+
+ def __init__(self, dim, dim_z):
+ self.kf = kinematic_kf(
+ dim,
+ 1,
+ dim_z=dim_z,
+ order_by_dim=False,
+ )
+ self.id = self.__class__.n_trackers
+ self.__class__.n_trackers += 1
+ self.time_since_update = 0
+ self.age = 0
+ self.hits = 0
+ self.hit_streak = 0
+
+ def update(self, z):
+ self.time_since_update = 0
+ self.hits += 1
+ self.hit_streak += 1
+ self.kf.update(z)
+
+ def predict(self):
+ self.kf.predict()
+ self.age += 1
+ if self.time_since_update > 0:
+ self.hit_streak = 0
+ self.time_since_update += 1
+ return self.state
+
+ @property
+ def state(self):
+ return self.kf.x.squeeze()[: self.kf.dim_z]
+
+ @state.setter
+ def state(self, state):
+ self.kf.x[: self.kf.dim_z] = state
+
+
+class Ellipse:
+ def __init__(self, x, y, width, height, theta):
+ self.x = x
+ self.y = y
+ self.width = width
+ self.height = height
+ self.theta = theta # in radians
+ self._geometry = None
+
+ @property
+ def parameters(self):
+ return self.x, self.y, self.width, self.height, self.theta
+
+ @property
+ def aspect_ratio(self):
+ return max(self.width, self.height) / min(self.width, self.height)
+
+ def calc_similarity_with(self, other_ellipse):
+ max_dist = max(
+ self.height, self.width, other_ellipse.height, other_ellipse.width
+ )
+ dist = math.sqrt(
+ (self.x - other_ellipse.x) ** 2 + (self.y - other_ellipse.y) ** 2
+ )
+
+ if max_dist == 0:
+ max_dist = 1
+
+ cost1 = 1 - min(dist / max_dist, 1)
+ cost2 = abs(math.cos(self.theta - other_ellipse.theta))
+ return 0.8 * cost1 + 0.2 * cost2 * cost1
+
+ def contains_points(self, xy, tol=0.1):
+ ca = math.cos(self.theta)
+ sa = math.sin(self.theta)
+ x_demean = xy[:, 0] - self.x
+ y_demean = xy[:, 1] - self.y
+ return (
+ ((ca * x_demean + sa * y_demean) ** 2 / (0.5 * self.width) ** 2)
+ + ((sa * x_demean - ca * y_demean) ** 2 / (0.5 * self.height) ** 2)
+ ) <= 1 + tol
+
+ def draw(self, show_axes=True, ax=None, **kwargs):
+ import matplotlib.pyplot as plt
+ from matplotlib.lines import Line2D
+ from matplotlib.transforms import Affine2D
+
+ if ax is None:
+ ax = plt.subplot(111, aspect="equal")
+ el = patches.Ellipse(
+ xy=(self.x, self.y),
+ width=self.width,
+ height=self.height,
+ angle=np.rad2deg(self.theta),
+ **kwargs,
+ )
+ ax.add_patch(el)
+ if show_axes:
+ major = Line2D([-self.width / 2, self.width / 2], [0, 0], lw=3, zorder=3)
+ minor = Line2D([0, 0], [-self.height / 2, self.height / 2], lw=3, zorder=3)
+ trans = (
+ Affine2D().rotate(self.theta).translate(self.x, self.y) + ax.transData
+ )
+ major.set_transform(trans)
+ minor.set_transform(trans)
+ ax.add_artist(major)
+ ax.add_artist(minor)
+
+
+class EllipseFitter:
+ def __init__(self, sd=2):
+ self.sd = sd
+ self.x = None
+ self.y = None
+ self.params = None
+ self._coeffs = None
+
+ def fit(self, xy):
+ self.x, self.y = xy[np.isfinite(xy).all(axis=1)].T
+ if len(self.x) < 3:
+ return None
+ if self.sd:
+ self.params = self._fit_error(self.x, self.y, self.sd)
+ else:
+ self._coeffs = self._fit(self.x, self.y)
+ self.params = self.calc_parameters(self._coeffs)
+ if not np.isnan(self.params).any():
+ el = Ellipse(*self.params)
+ # Regularize by forcing AR <= 5
+ # max_ar = 5
+ # if el.aspect_ratio >= max_ar:
+ # if el.height > el.width:
+ # el.width = el.height / max_ar
+ # else:
+ # el.height = el.width / max_ar
+ # Orient the ellipse such that it encompasses most points
+ # n_inside = el.contains_points(np.c_[self.x, self.y]).sum()
+ # el.theta += 0.5 * np.pi
+ # if el.contains_points(np.c_[self.x, self.y]).sum() < n_inside:
+ # el.theta -= 0.5 * np.pi
+ return el
+ return None
+
+ @staticmethod
+ @jit(nopython=True)
+ def _fit(x, y):
+ """
+ Least Squares ellipse fitting algorithm
+ Fit an ellipse to a set of X- and Y-coordinates.
+ See Halir and Flusser, 1998 for implementation details
+
+ :param x: ndarray, 1D trajectory
+ :param y: ndarray, 1D trajectory
+ :return: 1D ndarray of 6 coefficients of the general quadratic curve:
+ ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0
+ """
+ D1 = np.vstack((x * x, x * y, y * y))
+ D2 = np.vstack((x, y, np.ones_like(x)))
+ S1 = D1 @ D1.T
+ S2 = D1 @ D2.T
+ S3 = D2 @ D2.T
+ T = -np.linalg.inv(S3) @ S2.T
+ temp = S1 + S2 @ T
+ M = np.zeros_like(temp)
+ M[0] = temp[2] * 0.5
+ M[1] = -temp[1]
+ M[2] = temp[0] * 0.5
+ E, V = np.linalg.eig(M)
+ cond = 4 * V[0] * V[2] - V[1] ** 2
+ a1 = V[:, cond > 0][:, 0]
+ a2 = T @ a1
+ return np.hstack((a1, a2))
+
+ @staticmethod
+ @jit(nopython=True)
+ def _fit_error(x, y, sd):
+ """
+ Fit a sd-sigma covariance error ellipse to the data.
+
+ :param x: ndarray, 1D input of X coordinates
+ :param y: ndarray, 1D input of Y coordinates
+ :param sd: int, size of the error ellipse in 'standard deviation'
+ :return: ellipse center, semi-axes length, angle to the X-axis
+ """
+ cov = np.cov(x, y)
+ E, V = np.linalg.eigh(cov) # Returns the eigenvalues in ascending order
+ # r2 = chi2.ppf(2 * norm.cdf(sd) - 1, 2)
+ # height, width = np.sqrt(E * r2)
+ height, width = 2 * sd * np.sqrt(E)
+ a, b = V[:, 1]
+ rotation = math.atan2(b, a) % np.pi
+ return [np.mean(x), np.mean(y), width, height, rotation]
+
+ @staticmethod
+ @jit(nopython=True)
+ def calc_parameters(coeffs):
+ """
+ Calculate ellipse center coordinates, semi-axes lengths, and
+ the counterclockwise angle of rotation from the x-axis to the ellipse major axis.
+ Visit http://mathworld.wolfram.com/Ellipse.html
+ for how to estimate ellipse parameters.
+
+ :param coeffs: list of fitting coefficients
+ :return: center: 1D ndarray, semi-axes: 1D ndarray, angle: float
+ """
+ # The general quadratic curve has the form:
+ # ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0
+ a, b, c, d, f, g = coeffs
+ b *= 0.5
+ d *= 0.5
+ f *= 0.5
+
+ # Ellipse center coordinates
+ x0 = (c * d - b * f) / (b * b - a * c)
+ y0 = (a * f - b * d) / (b * b - a * c)
+
+ # Semi-axes lengths
+ num = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
+ den1 = (b * b - a * c) * (np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c))
+ den2 = (b * b - a * c) * (-np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c))
+ major = np.sqrt(num / den1)
+ minor = np.sqrt(num / den2)
+
+ # Angle to the horizontal
+ if b == 0:
+ if a < c:
+ phi = 0
+ else:
+ phi = np.pi / 2
+ else:
+ if a < c:
+ phi = np.arctan(2 * b / (a - c)) / 2
+ else:
+ phi = np.pi / 2 + np.arctan(2 * b / (a - c)) / 2
+
+ return [x0, y0, 2 * major, 2 * minor, phi]
+
+
+class EllipseTracker(BaseTracker):
+ def __init__(self, params):
+ super().__init__(dim=5, dim_z=5)
+ self.kf.R[2:, 2:] *= 10.0
+ # High uncertainty to the unobservable initial velocities
+ self.kf.P[5:, 5:] *= 1000.0
+ self.kf.P *= 10.0
+ self.kf.Q[5:, 5:] *= 0.01
+ self.state = params
+
+ @BaseTracker.state.setter
+ def state(self, params):
+ state = np.asarray(params).reshape((-1, 1))
+ super(EllipseTracker, type(self)).state.fset(self, state)
+
+
+class SkeletonTracker(BaseTracker):
+ def __init__(self, n_bodyparts):
+ super().__init__(dim=n_bodyparts * 2, dim_z=n_bodyparts)
+ self.kf.Q[self.kf.dim_z :, self.kf.dim_z :] *= 10
+ self.kf.R[self.kf.dim_z :, self.kf.dim_z :] *= 0.01
+ self.kf.P[self.kf.dim_z :, self.kf.dim_z :] *= 1000
+
+ def update(self, pose):
+ flat = pose.reshape((-1, 1))
+ empty = np.isnan(flat).squeeze()
+ if empty.any():
+ H = self.kf.H.copy()
+ H[empty] = 0
+ flat[empty] = 0
+ self.kf.update(flat, H=H)
+ else:
+ super().update(flat)
+
+ @BaseTracker.state.setter
+ def state(self, pose):
+ curr_pose = pose.copy()
+ empty = np.isnan(curr_pose).all(axis=1)
+ if empty.any():
+ fill = np.nanmean(pose, axis=0)
+ curr_pose[empty] = fill
+ super(SkeletonTracker, type(self)).state.fset(self, curr_pose.reshape((-1, 1)))
+
+
+class BoxTracker(BaseTracker):
+ def __init__(self, bbox):
+ super().__init__(dim=4, dim_z=4)
+ self.kf = KalmanFilter(dim_x=7, dim_z=4)
+ self.kf.F = np.array(
+ [
+ [1, 0, 0, 0, 1, 0, 0],
+ [0, 1, 0, 0, 0, 1, 0],
+ [0, 0, 1, 0, 0, 0, 1],
+ [0, 0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0, 0],
+ [0, 0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0, 1],
+ ]
+ )
+ self.kf.H = np.array(
+ [
+ [1, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0],
+ [0, 0, 1, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0, 0],
+ ]
+ )
+ self.kf.R[2:, 2:] *= 10.0
+ # Give high uncertainty to the unobservable initial velocities
+ self.kf.P[4:, 4:] *= 1000.0
+ self.kf.P *= 10.0
+ self.kf.Q[-1, -1] *= 0.01
+ self.kf.Q[4:, 4:] *= 0.01
+ self.state = bbox
+
+ def update(self, bbox):
+ super().update(self.convert_bbox_to_z(bbox))
+
+ def predict(self):
+ if (self.kf.x[6] + self.kf.x[2]) <= 0:
+ self.kf.x[6] *= 0.0
+ return super().predict()
+
+ @property
+ def state(self):
+ return self.convert_x_to_bbox(self.kf.x)[0]
+
+ @state.setter
+ def state(self, bbox):
+ state = self.convert_bbox_to_z(bbox)
+ super(BoxTracker, type(self)).state.fset(self, state)
+
+ @staticmethod
+ def convert_x_to_bbox(x, score=None):
+ """
+ Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
+ [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
+ """
+ w = np.sqrt(x[2] * x[3])
+ h = x[2] / w
+ if score is None:
+ return np.array(
+ [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]
+ ).reshape((1, 4))
+ else:
+ return np.array(
+ [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]
+ ).reshape((1, 5))
+
+ @staticmethod
+ def convert_bbox_to_z(bbox):
+ """
+ Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
+ [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
+ the aspect ratio
+ """
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ x = bbox[0] + w / 2.0
+ y = bbox[1] + h / 2.0
+ s = w * h # scale is just area
+ r = w / float(h)
+ return np.array([x, y, s, r]).reshape((4, 1))
+
+
+class SORTBase(metaclass=abc.ABCMeta):
+ def __init__(self):
+ self.n_frames = 0
+ self.trackers = []
+
+ @abc.abstractmethod
+ def track(self):
+ pass
+
+
+class SORTEllipse(SORTBase):
+ def __init__(self, max_age, min_hits, iou_threshold, sd=2):
+ self.max_age = max_age
+ self.min_hits = min_hits
+ self.iou_threshold = iou_threshold
+ self.fitter = EllipseFitter(sd)
+ EllipseTracker.n_trackers = 0
+ super().__init__()
+
+ def track(self, poses, identities=None):
+ self.n_frames += 1
+
+ trackers = np.zeros((len(self.trackers), 6))
+ for i in range(len(trackers)):
+ trackers[i, :5] = self.trackers[i].predict()
+ empty = np.isnan(trackers).any(axis=1)
+ trackers = trackers[~empty]
+ for ind in np.flatnonzero(empty)[::-1]:
+ self.trackers.pop(ind)
+
+ ellipses = []
+ pred_ids = []
+ for i, pose in enumerate(poses):
+ el = self.fitter.fit(pose)
+ if el is not None:
+ ellipses.append(el)
+ if identities is not None:
+ pred_ids.append(mode(identities[i])[0][0])
+ if not len(trackers):
+ matches = np.empty((0, 2), dtype=int)
+ unmatched_detections = np.arange(len(ellipses))
+ unmatched_trackers = np.empty((0, 6), dtype=int)
+ else:
+ ellipses_trackers = [Ellipse(*t[:5]) for t in trackers]
+ cost_matrix = np.zeros((len(ellipses), len(ellipses_trackers)))
+ for i, el in enumerate(ellipses):
+ for j, el_track in enumerate(ellipses_trackers):
+ cost = el.calc_similarity_with(el_track)
+ if identities is not None:
+ match = 2 if pred_ids[i] == self.trackers[j].id_ else 1
+ cost *= match
+ cost_matrix[i, j] = cost
+ row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=True)
+ unmatched_detections = [
+ i for i, _ in enumerate(ellipses) if i not in row_indices
+ ]
+ unmatched_trackers = [
+ j for j, _ in enumerate(trackers) if j not in col_indices
+ ]
+ matches = []
+ for row, col in zip(row_indices, col_indices):
+ val = cost_matrix[row, col]
+ # diff = val - cost_matrix
+ # diff[row, col] += val
+ # if (
+ # val < self.iou_threshold
+ # or np.any(diff[row] <= 0.2)
+ # or np.any(diff[:, col] <= 0.2)
+ # ):
+ if val < self.iou_threshold:
+ unmatched_detections.append(row)
+ unmatched_trackers.append(col)
+ else:
+ matches.append([row, col])
+ if not len(matches):
+ matches = np.empty((0, 2), dtype=int)
+ else:
+ matches = np.stack(matches)
+ unmatched_trackers = np.asarray(unmatched_trackers)
+ unmatched_detections = np.asarray(unmatched_detections)
+
+ animalindex = []
+ for t, tracker in enumerate(self.trackers):
+ if t not in unmatched_trackers:
+ ind = matches[matches[:, 1] == t, 0][0]
+ animalindex.append(ind)
+ tracker.update(ellipses[ind].parameters)
+ else:
+ animalindex.append(-1)
+
+ for i in unmatched_detections:
+ trk = EllipseTracker(ellipses[i].parameters)
+ if identities is not None:
+ trk.id_ = mode(identities[i])[0][0]
+ self.trackers.append(trk)
+ animalindex.append(i)
+
+ i = len(self.trackers)
+ ret = []
+ for trk in reversed(self.trackers):
+ d = trk.state
+ if (trk.time_since_update < 1) and (
+ trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits
+ ):
+ ret.append(
+ np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape(
+ 1, -1
+ )
+ ) # for DLC we also return the original animalid
+ # +1 as MOT benchmark requires positive >> this is removed for DLC!
+ i -= 1
+ # remove dead tracklet
+ if trk.time_since_update > self.max_age:
+ self.trackers.pop(i)
+
+ if len(ret) > 0:
+ return np.concatenate(ret)
+ return np.empty((0, 7))
+
+
+class SORTSkeleton(SORTBase):
+ def __init__(self, n_bodyparts, max_age=20, min_hits=3, oks_threshold=0.5):
+ self.n_bodyparts = n_bodyparts
+ self.max_age = max_age
+ self.min_hits = min_hits
+ self.oks_threshold = oks_threshold
+ SkeletonTracker.n_trackers = 0
+ super().__init__()
+
+ @staticmethod
+ def weighted_hausdorff(x, y):
+ # Modified from scipy source code:
+ # - to restrict its use to 2D
+ # - to get rid of shuffling (since arrays are only (nbodyparts * 3) element long)
+ # TODO - factor in keypoint confidence (and weight by # of observations??)
+ cmax = 0
+ for i in range(x.shape[0]):
+ no_break_occurred = True
+ cmin = np.inf
+ for j in range(y.shape[0]):
+ d = (x[i, 0] - y[j, 0]) ** 2 + (x[i, 1] - y[j, 1]) ** 2
+ if d < cmax:
+ no_break_occurred = False
+ break
+ if d < cmin:
+ cmin = d
+ if cmin != np.inf and cmin > cmax and no_break_occurred:
+ cmax = cmin
+ return np.sqrt(cmax)
+
+ @staticmethod
+ def object_keypoint_similarity(x, y):
+ mask = ~np.isnan(x * y).all(axis=1) # Intersection visible keypoints
+ xx = x[mask]
+ yy = y[mask]
+ dist = np.linalg.norm(xx - yy, axis=1)
+ scale = np.sqrt(
+ np.product(np.ptp(yy, axis=0))
+ ) # square root of bounding box area
+ oks = np.exp(-0.5 * (dist / (0.05 * scale)) ** 2)
+ return np.mean(oks)
+
+ def calc_pairwise_hausdorff_dist(self, poses, poses_ref):
+ mat = np.zeros((len(poses), len(poses_ref)))
+ for i, pose in enumerate(poses):
+ for j, pose_ref in enumerate(poses_ref):
+ mat[i, j] = self.weighted_hausdorff(pose, pose_ref)
+ return mat
+
+ def calc_pairwise_oks(self, poses, poses_ref):
+ mat = np.zeros((len(poses), len(poses_ref)))
+ for i, pose in enumerate(poses):
+ for j, pose_ref in enumerate(poses_ref):
+ mat[i, j] = self.object_keypoint_similarity(pose, pose_ref)
+ return mat
+
+ def track(self, poses):
+ self.n_frames += 1
+
+ if not len(self.trackers):
+ for pose in poses:
+ tracker = SkeletonTracker(self.n_bodyparts)
+ tracker.state = pose
+ self.trackers.append(tracker)
+
+ poses_ref = []
+ for i, tracker in enumerate(self.trackers):
+ pose_ref = tracker.predict()
+ poses_ref.append(pose_ref.reshape((-1, 2)))
+
+ # mat = self.calc_pairwise_oks(poses, poses_ref)
+ mat = self.calc_pairwise_hausdorff_dist(poses, poses_ref)
+ row_indices, col_indices = linear_sum_assignment(mat, maximize=False)
+
+ unmatched_poses = [p for p, _ in enumerate(poses) if p not in row_indices]
+ unmatched_trackers = [
+ t for t, _ in enumerate(poses_ref) if t not in col_indices
+ ]
+ # Remove matched detections with low OKS
+ # matches = []
+ # for row, col in zip(row_indices, col_indices):
+ # if mat[row, col] < self.oks_threshold:
+ # unmatched_poses.append(row)
+ # unmatched_trackers.append(col)
+ # else:
+ # matches.append([row, col])
+ # if not len(matches):
+ # matches = np.empty((0, 2), dtype=int)
+ # else:
+ # matches = np.stack(matches)
+ matches = np.c_[row_indices, col_indices]
+
+ animalindex = []
+ for t, tracker in enumerate(self.trackers):
+ if t not in unmatched_trackers:
+ ind = matches[matches[:, 1] == t, 0][0]
+ animalindex.append(ind)
+ tracker.update(poses[ind])
+ else:
+ animalindex.append(-1)
+
+ for i in unmatched_poses:
+ tracker = SkeletonTracker(self.n_bodyparts)
+ tracker.state = poses[i]
+ self.trackers.append(tracker)
+ animalindex.append(i)
+
+ states = []
+ i = len(self.trackers)
+ for tracker in reversed(self.trackers):
+ i -= 1
+ if tracker.time_since_update > self.max_age:
+ self.trackers.pop()
+ continue
+ state = tracker.predict()
+ states.append(np.r_[state, [tracker.id, int(animalindex[i])]])
+ if len(states) > 0:
+ return np.stack(states)
+ return np.empty((0, self.n_bodyparts * 2 + 2))
+
+
+class SORTBox(SORTBase):
+ def __init__(self, max_age, min_hits, iou_threshold):
+ self.max_age = max_age
+ self.min_hits = min_hits
+ self.iou_threshold = iou_threshold
+ BoxTracker.n_trackers = 0
+ super().__init__()
+
+ def track(self, dets):
+ self.n_frames += 1
+
+ trackers = np.zeros((len(self.trackers), 5))
+ for i in range(len(trackers)):
+ trackers[i, :4] = self.trackers[i].predict()
+ empty = np.isnan(trackers).any(axis=1)
+ trackers = trackers[~empty]
+ for ind in np.flatnonzero(empty)[::-1]:
+ self.trackers.pop(ind)
+
+ matched, unmatched_dets, unmatched_trks = self.match_detections_to_trackers(
+ dets, trackers, self.iou_threshold
+ )
+
+ # update matched trackers with assigned detections
+ animalindex = []
+ for t, trk in enumerate(self.trackers):
+ if t not in unmatched_trks:
+ d = matched[np.where(matched[:, 1] == t)[0], 0]
+ animalindex.append(d[0])
+ trk.update(dets[d, :][0]) # update coordinates
+ else:
+ animalindex.append("nix") # lost trk!
+
+ # create and initialise new trackers for unmatched detections
+ for i in unmatched_dets:
+ trk = BoxTracker(dets[i, :])
+ self.trackers.append(trk)
+ animalindex.append(i)
+
+ i = len(self.trackers)
+ ret = []
+ for trk in reversed(self.trackers):
+ d = trk.state
+ if (trk.time_since_update < 1) and (
+ trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits
+ ):
+ ret.append(
+ np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape(
+ 1, -1
+ )
+ ) # for DLC we also return the original animalid
+ # +1 as MOT benchmark requires positive >> this is removed for DLC!
+ i -= 1
+ # remove dead tracklet
+ if trk.time_since_update > self.max_age:
+ self.trackers.pop(i)
+
+ if len(ret) > 0:
+ return np.concatenate(ret)
+ return np.empty((0, 5))
+
+ @staticmethod
+ def match_detections_to_trackers(detections, trackers, iou_threshold):
+ """
+ Assigns detections to tracked object (both represented as bounding boxes)
+
+ Returns 3 lists of matches, unmatched_detections and unmatched_trackers
+ """
+ if not len(trackers):
+ return (
+ np.empty((0, 2), dtype=int),
+ np.arange(len(detections)),
+ np.empty((0, 5), dtype=int),
+ )
+ iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32)
+
+ for d, det in enumerate(detections):
+ for t, trk in enumerate(trackers):
+ iou_matrix[d, t] = calc_iou(det, trk)
+ row_indices, col_indices = linear_sum_assignment(-iou_matrix)
+
+ unmatched_detections = []
+ for d, det in enumerate(detections):
+ if d not in row_indices:
+ unmatched_detections.append(d)
+ unmatched_trackers = []
+ for t, trk in enumerate(trackers):
+ if t not in col_indices:
+ unmatched_trackers.append(t)
+
+ # filter out matched with low IOU
+ matches = []
+ for row, col in zip(row_indices, col_indices):
+ if iou_matrix[row, col] < iou_threshold:
+ unmatched_detections.append(row)
+ unmatched_trackers.append(col)
+ else:
+ matches.append([row, col])
+ if not len(matches):
+ matches = np.empty((0, 2), dtype=int)
+ else:
+ matches = np.stack(matches)
+ return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
+
+
+def fill_tracklets(tracklets, trackers, animals, imname):
+ for content in trackers:
+ tracklet_id, pred_id = content[-2:].astype(int)
+ if tracklet_id not in tracklets:
+ tracklets[tracklet_id] = {}
+ if pred_id != -1:
+ tracklets[tracklet_id][imname] = np.asarray(animals[pred_id])
+ else: # Resort to the tracker prediction
+ xy = np.asarray(content[:-2])
+ pred = np.insert(xy, range(2, len(xy) + 1, 2), 1)
+ tracklets[tracklet_id][imname] = np.asarray(pred)
+
+
+def calc_bboxes_from_keypoints(data, slack=0, offset=0):
+ data = np.asarray(data)
+ if data.shape[-1] < 3:
+ raise ValueError("Data should be of shape (n_animals, n_bodyparts, 3)")
+
+ if data.ndim != 3:
+ data = np.expand_dims(data, axis=0)
+ bboxes = np.full((data.shape[0], 5), np.nan)
+ bboxes[:, :2] = np.nanmin(data[..., :2], axis=1) - slack # X1, Y1
+ bboxes[:, 2:4] = np.nanmax(data[..., :2], axis=1) + slack # X2, Y2
+ bboxes[:, -1] = np.nanmean(data[..., 2]) # Average confidence
+ bboxes[:, [0, 2]] += offset
+ return bboxes
+
+
+def reconstruct_all_ellipses(data, sd):
+ xy = data.droplevel("scorer", axis=1).drop("likelihood", axis=1, level=-1)
+ if "single" in xy:
+ xy.drop("single", axis=1, level="individuals", inplace=True)
+ animals = xy.columns.get_level_values("individuals").unique()
+ nrows = xy.shape[0]
+ ellipses = np.full((len(animals), nrows, 5), np.nan)
+ fitter = EllipseFitter(sd)
+ for n, animal in enumerate(animals):
+ data = xy.xs(animal, axis=1, level="individuals").values.reshape((nrows, -1, 2))
+ for i, coords in enumerate(tqdm(data)):
+ el = fitter.fit(coords.astype(np.float64))
+ if el is not None:
+ ellipses[n, i] = el.parameters
+ return ellipses
+
+
+def _track_individuals(
+ individuals, min_hits=1, max_age=5, similarity_threshold=0.6, track_method="ellipse"
+):
+ if track_method not in TRACK_METHODS:
+ raise ValueError(f"Unknown {track_method} tracker.")
+
+ if track_method == "ellipse":
+ tracker = SORTEllipse(max_age, min_hits, similarity_threshold)
+ elif track_method == "box":
+ tracker = SORTBox(max_age, min_hits, similarity_threshold)
+ else:
+ n_bodyparts = individuals[0][0].shape[0]
+ tracker = SORTSkeleton(n_bodyparts, max_age, min_hits, similarity_threshold)
+
+ tracklets = defaultdict(dict)
+ all_hyps = dict()
+ for i, (multi, single) in enumerate(tqdm(individuals)):
+ if single is not None:
+ tracklets["single"][i] = single
+ if multi is None:
+ continue
+ if track_method == "box":
+ # TODO: get cropping parameters and utilize!
+ xy = calc_bboxes_from_keypoints(multi)
+ else:
+ xy = multi[..., :2]
+ hyps = tracker.track(xy)
+ all_hyps[i] = hyps
+ for hyp in hyps:
+ tracklet_id, pred_id = hyp[-2:].astype(int)
+ if pred_id != -1:
+ tracklets[tracklet_id][i] = multi[pred_id]
+ return tracklets, all_hyps
diff --git a/deeplabcut/core/visualization.py b/deeplabcut/core/visualization.py
new file mode 100644
index 0000000000..1911a4f769
--- /dev/null
+++ b/deeplabcut/core/visualization.py
@@ -0,0 +1,236 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Visualization methods for """
+from __future__ import annotations
+
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def form_figure(nx, ny) -> tuple[plt.Figure, plt.Axes]:
+ """Forms a figure on which to plot images"""
+ fig, ax = plt.subplots(frameon=False)
+ ax.set_xlim(0, nx)
+ ax.set_ylim(0, ny)
+ ax.axis("off")
+ ax.invert_yaxis()
+ fig.tight_layout()
+ return fig, ax
+
+
+def visualize_scoremaps(
+ image: np.ndarray,
+ scmap: np.ndarray,
+) -> tuple[plt.Figure, plt.Axes]:
+ """Plots scoremaps as an image overlay.
+
+ Args:
+ image: An image as a numpy array of shape (h, w, channels)
+ scmap: A scoremap of shape (h, w)
+
+ Returns:
+ The figure and axis on which the image scoremap was plot.
+ """
+ ny, nx = np.shape(image)[:2]
+ fig, ax = form_figure(nx, ny)
+ ax.imshow(image)
+ ax.imshow(scmap, alpha=0.5)
+ return fig, ax
+
+
+def visualize_locrefs(
+ image: np.ndarray,
+ scmap: np.ndarray,
+ locref_x: np.ndarray,
+ locref_y: np.ndarray,
+ step: int = 5,
+ zoom_width: int = 0,
+) -> tuple[plt.Figure, plt.Axes]:
+ """Plots a scoremap and the corresponding location refinement field on an image.
+
+ Args:
+ image: An image as a numpy array of shape (h, w, channels)
+ scmap: A scoremap of shape (h, w)
+ locref_x: The x-coordinate of the location refinement field, of shape (h, w)
+ locref_y: The y-coordinate of the location refinement field, of shape (h, w)
+ step: The step with which to plot the location refinement field.
+ zoom_width: The zoom width with which to plot the scoremaps.
+
+ Returns:
+ The figure and axis on which the image scoremap and locref field were plot.
+ """
+ fig, ax = visualize_scoremaps(image, scmap)
+ X, Y = np.meshgrid(np.arange(locref_x.shape[1]), np.arange(locref_x.shape[0]))
+ M = np.zeros(locref_x.shape, dtype=bool)
+ M[scmap < 0.5] = True
+ U = np.ma.masked_array(locref_x, mask=M)
+ V = np.ma.masked_array(locref_y, mask=M)
+ ax.quiver(
+ X[::step, ::step],
+ Y[::step, ::step],
+ U[::step, ::step],
+ V[::step, ::step],
+ color="r",
+ units="x",
+ scale_units="xy",
+ scale=1,
+ angles="xy",
+ )
+ if zoom_width > 0:
+ maxloc = np.unravel_index(np.argmax(scmap), scmap.shape)
+ ax.set_xlim(maxloc[1] - zoom_width, maxloc[1] + zoom_width)
+ ax.set_ylim(maxloc[0] + zoom_width, maxloc[0] - zoom_width)
+ return fig, ax
+
+
+def visualize_paf(
+ image: np.ndarray,
+ paf: np.ndarray,
+ step: int = 5,
+ colors: list | None = None,
+) -> tuple[plt.Figure, plt.Axes]:
+ """Plots the PAF on top of the image.
+
+ Args:
+ image: Shape (height, width, channels). The image on which the model was run.
+ paf: Shape (height, width, 2 * len(paf_graph)). The PAF output by the model.
+ step: The step with which to plot the scoremaps.
+ colors: The colormap to use.
+
+ Returns:
+ The figure and axis on which the image PAF was plot.
+ """
+ ny, nx = np.shape(image)[:2]
+ fig, ax = form_figure(nx, ny)
+ ax.imshow(image)
+ n_fields = paf.shape[2]
+ if colors is None:
+ colors = ["r"] * n_fields
+ for n in range(n_fields):
+ U = paf[:, :, n, 0]
+ V = paf[:, :, n, 1]
+ X, Y = np.meshgrid(np.arange(U.shape[1]), np.arange(U.shape[0]))
+ M = np.zeros(U.shape, dtype=bool)
+ M[U**2 + V**2 < 0.5 * 0.5**2] = True
+ U = np.ma.masked_array(U, mask=M)
+ V = np.ma.masked_array(V, mask=M)
+ ax.quiver(
+ X[::step, ::step],
+ Y[::step, ::step],
+ U[::step, ::step],
+ V[::step, ::step],
+ scale=50,
+ headaxislength=4,
+ alpha=1,
+ width=0.002,
+ color=colors[n],
+ angles="xy",
+ )
+ return fig, ax
+
+
+def generate_model_output_plots(
+ output_folder: Path,
+ image_name: str,
+ bodypart_names: list[str],
+ bodyparts_to_plot: list[str],
+ image: np.ndarray,
+ scmap: np.ndarray,
+ locref: np.ndarray | None = None,
+ paf: np.ndarray | None = None,
+ paf_graph: list[tuple[int, int]] | None = None,
+ paf_all_in_one: bool = True,
+ paf_colormap: str = "rainbow",
+ output_suffix: str = "",
+) -> None:
+ """Generates model output plots (maps) for an image and saves them to disk.
+
+ Args:
+ output_folder: The folder in which the plots should be saved.
+ image_name: The name of the image for which the plots were generated.
+ bodypart_names: The names of bodyparts the model outputs.
+ bodyparts_to_plot: The names of bodyparts that should be plot.
+ image: Shape (height, width, channels). The image on which the model was run.
+ scmap: Shape (height, width, num_bodyparts). The scoremaps output by the model.
+ locref: Shape (height, width, num_bodyparts, 2). Optionally, the location
+ refinement fields output by the model.
+ paf: Shape (height, width, 2 * len(paf_graph)). Optionally, the part-affinity
+ fields output by the model.
+ paf_graph: Must be set if paf is not None. The PAF graph used to assemble.
+ paf_all_in_one: Whether to plot all PAFs in a single image.
+ paf_colormap: The colormap to use for the PAF maps.
+ output_suffix: The filename suffix for the maps to output.
+ """
+ def _filename(map_name) -> str:
+ return f"{image_name}_{map_name}_{output_suffix}.png"
+
+ to_plot = [
+ i for i, bpt in enumerate(bodypart_names) if bpt in bodyparts_to_plot
+ ]
+ if len(to_plot) > 1:
+ map_ = scmap[:, :, to_plot].sum(axis=2)
+ elif len(to_plot) == 1 and len(bodypart_names) > 1:
+ map_ = scmap[:, :, to_plot[0]]
+ else:
+ map_ = scmap[..., 0]
+
+ fig1, _ = visualize_scoremaps(image, map_)
+ fig1.savefig(output_folder / _filename("scmap"))
+
+ if locref is not None:
+ if len(to_plot) > 1:
+ map_ = scmap[:, :, to_plot]
+ locref_x_ = locref[:, :, to_plot, 0]
+ locref_y_ = locref[:, :, to_plot, 1]
+ # only get the locref fields around their respective detections
+ locref_x_[map_ < 0.5] = 0
+ locref_y_[map_ < 0.5] = 0
+ # combine locrefs
+ map_ = map_.sum(axis=2)
+ locref_x_ = locref_x_.sum(axis=2)
+ locref_y_ = locref_y_.sum(axis=2)
+ elif len(to_plot) == 1 and len(bodypart_names) > 1:
+ locref_x_ = locref[:, :, to_plot[0], 0]
+ locref_y_ = locref[:, :, to_plot[0], 1]
+ else:
+ locref_x_ = locref[..., 0]
+ locref_y_ = locref[..., 1]
+
+ fig2, _ = visualize_locrefs(image, map_, locref_x_, locref_y_)
+ fig2.savefig(output_folder / _filename("locref"))
+
+ if paf is not None:
+ if paf_graph is None:
+ raise ValueError(f"When plotting the PAF, you must pass the ``paf_graph``")
+
+ edge_list = []
+ for n, edge in enumerate(paf_graph):
+ if any(ind in to_plot for ind in edge):
+ e0, e1 = edge
+ edge_list.append(
+ [(2 * n, 2 * n + 1), (bodypart_names[e0], bodypart_names[e1])]
+ )
+
+ if paf_all_in_one:
+ inds = [elem[0] for elem in edge_list]
+ n_inds = len(inds)
+ cmap = plt.cm.get_cmap(paf_colormap, n_inds)
+ colors = cmap(range(n_inds))
+ fig3, _ = visualize_paf(image, paf[:, :, inds], colors=colors)
+ fig3.savefig(output_folder / _filename("paf"))
+ else:
+ for inds, names in edge_list:
+ fig3, _ = visualize_paf(image, paf[:, :, [inds]])
+ fig3.savefig(output_folder / _filename(f"paf_{'_'.join(names)}"))
+
+ plt.close("all")
diff --git a/deeplabcut/core/weight_init.py b/deeplabcut/core/weight_init.py
new file mode 100644
index 0000000000..8a6d374e8a
--- /dev/null
+++ b/deeplabcut/core/weight_init.py
@@ -0,0 +1,206 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Classes to configure how to initialize model weights"""
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+
+import numpy as np
+
+
+@dataclass
+class WeightInitialization:
+ """Configures weights initialization when transfer learning or fine-tuning models
+
+ Args:
+ snapshot_path: The path to the snapshot used to initialize pose model weights
+ when training a model.
+ detector_snapshot_path: The path to the snapshot used to initialize detector
+ weights when training a model.
+ dataset: Optionally, the dataset on which the snapshots were trained. Required
+ when fine-tuning SuperAnimal models.
+ with_decoder: Whether to load the decoder weights as well.
+ memory_replay: Only when ``with_decoder=True``. Whether to train the model with
+ memory replay, so that it predicts all SuperAnimal (or previous project)
+ bodyparts.
+ conversion_array: The mapping from SuperAnimal (or other project, on which the
+ weights were trained) to project bodyparts. Required when
+ `with_decoder=True`.
+ An array [7, 0, 1] means the project has 3 bodyparts, where the 1st bodypart
+ corresponds to the 8th bodypart in the pretrained model, the 2nd to the 1st
+ and the 3rd to the 2nd (as arrays are 0-indexed).
+ bodyparts: Optionally, the name of each bodypart entry in the conversion array.
+ """
+
+ snapshot_path: Path
+ detector_snapshot_path: Path | None = None
+ dataset: str | None = None
+ with_decoder: bool = False
+ memory_replay: bool = False
+ conversion_array: np.ndarray | None = None
+ bodyparts: list[str] | None = None
+
+ def __post_init__(self):
+ if self.memory_replay and not self.with_decoder:
+ raise ValueError(
+ "You cannot train a model with memory replay if you do not keep the "
+ "decoder layers (``with_decoder=True``), but you passed "
+ "`memory_replay=True` and `with_decoder=False`. Please change your "
+ "WeightInitialization parameters."
+ )
+
+ if self.with_decoder and self.conversion_array is None:
+ raise ValueError(
+ f"You must specify a conversion_array to initialize decoder weights "
+ f"(``with_decoder=True``)."
+ )
+
+ if self.bodyparts is not None and self.conversion_array is None:
+ raise ValueError(
+ f"Specifying bodyparts should only be done when `with_decoder=True` and"
+ f" the conversion array is specified."
+ )
+
+ if self.conversion_array is not None and self.bodyparts is not None:
+ if not len(self.conversion_array) == len(self.bodyparts):
+ raise ValueError(
+ f"There must be the same number of elements in the bodyparts list "
+ "and conv. array; found {self.bodyparts}, {self.conversion_array}"
+ )
+
+ def to_dict(self) -> dict:
+ """Returns: the weight initialization as a dict"""
+ data = dict()
+ if self.dataset is not None:
+ data["dataset"] = self.dataset
+
+ data["snapshot_path"] = str(self.snapshot_path)
+ if self.detector_snapshot_path is not None:
+ data["detector_snapshot_path"] = str(self.detector_snapshot_path)
+
+ data["with_decoder"] = self.with_decoder
+ data["memory_replay"] = self.memory_replay
+
+ if self.conversion_array is not None:
+ data["conversion_array"] = self.conversion_array.tolist()
+
+ if self.bodyparts is not None:
+ data["bodyparts"] = self.bodyparts
+
+ return data
+
+ @staticmethod
+ def from_dict(data: dict) -> "WeightInitialization":
+ if "snapshot_path" not in data:
+ return WeightInitialization.from_dict_legacy(data)
+
+ detector_snapshot_path = data.get("detector_snapshot_path")
+ if detector_snapshot_path is not None:
+ detector_snapshot_path = Path(detector_snapshot_path)
+
+ conversion_array = data.get("conversion_array")
+ if conversion_array is not None:
+ conversion_array = np.array(conversion_array, dtype=int)
+
+ return WeightInitialization(
+ snapshot_path=Path(data["snapshot_path"]),
+ detector_snapshot_path=detector_snapshot_path,
+ dataset=data.get("dataset"),
+ with_decoder=data["with_decoder"],
+ memory_replay=data["memory_replay"],
+ conversion_array=conversion_array,
+ bodyparts=data.get("bodyparts"),
+ )
+
+ @staticmethod
+ def from_dict_legacy(data: dict) -> "WeightInitialization":
+ """Deals with weight initialization that were created before 3.0.0rc5"""
+ import deeplabcut.pose_estimation_pytorch.modelzoo.utils as utils
+
+ conversion_array = data.get("conversion_array")
+ if conversion_array is not None:
+ conversion_array = np.array(conversion_array, dtype=int)
+
+ return WeightInitialization(
+ snapshot_path=utils.get_super_animal_snapshot_path(
+ dataset=data["dataset"],
+ model_name="hrnet_w32",
+ ),
+ detector_snapshot_path=utils.get_super_animal_snapshot_path(
+ dataset=data["dataset"],
+ model_name="fasterrcnn_resnet50_fpn_v2",
+ ),
+ with_decoder=data["with_decoder"],
+ memory_replay=data["memory_replay"],
+ conversion_array=conversion_array,
+ bodyparts=data.get("bodyparts"),
+ )
+
+ @staticmethod
+ def build(
+ cfg: dict,
+ super_animal: str,
+ model_name: str = "hrnet_w32",
+ detector_name: str = "fasterrcnn_resnet50_fpn_v2",
+ with_decoder: bool = False,
+ memory_replay: bool = False,
+ customized_pose_checkpoint: str | None = None,
+ customized_detector_checkpoint: str | None = None,
+ ) -> "WeightInitialization":
+ """Builds a WeightInitialization for a project
+
+ `WeightInitialization.build` is deprecated and will be removed in a future
+ version of DeepLabCut. Please use `build_weight_init` from `deeplabcut.modelzoo`
+ instead.
+
+ Args:
+ cfg: The project's configuration.
+ super_animal: The SuperAnimal model with which to initialize weights.
+ model_name: The name of the model architecture for which to load the weights
+ (defaults to "hrnet_w32" for backwards compatibility).
+ detector_name: The name of the detector architecture for which to load the
+ weights (defaults to "fasterrcnn_resnet50_fpn_v2" for backwards
+ compatibility).
+ with_decoder: Whether to load the decoder weights as well. If this is true,
+ a conversion table must be specified for the given SuperAnimal in the
+ project configuration file. See
+ ``deeplabcut.modelzoo.utils.create_conversion_table`` to create a
+ conversion table.
+ memory_replay: Only when ``with_decoder=True``. Whether to train the model
+ with memory replay, so that it predicts all SuperAnimal bodyparts.
+ customized_pose_checkpoint: A customized SuperAnimal pose checkpoint, as an
+ alternative to the Hugging Face one
+ customized_detector_checkpoint: A customized SuperAnimal detector
+ checkpoint, as an alternative to the Hugging Face one
+
+ Returns:
+ The built WeightInitialization.
+ """
+ from deeplabcut.modelzoo import build_weight_init
+ deprecation_warning = (
+ "The `WeightInitialization.build` is deprecated and will be removed in a "
+ "future version of DeepLabCut. Please use `build_weight_init` from "
+ "`deeplabcut.modelzoo` instead."
+ )
+ warnings.warn(deprecation_warning, DeprecationWarning)
+
+ return build_weight_init(
+ cfg,
+ super_animal,
+ model_name,
+ detector_name,
+ with_decoder,
+ memory_replay,
+ customized_pose_checkpoint,
+ customized_detector_checkpoint,
+ )
diff --git a/deeplabcut/create_project/demo_data.py b/deeplabcut/create_project/demo_data.py
index c495c89ba6..9e53eb7770 100644
--- a/deeplabcut/create_project/demo_data.py
+++ b/deeplabcut/create_project/demo_data.py
@@ -13,10 +13,15 @@
from pathlib import Path
import deeplabcut
+from deeplabcut.core.engine import Engine
from deeplabcut.utils import auxiliaryfunctions
-def load_demo_data(config, createtrainingset=True):
+def load_demo_data(
+ config: str,
+ createtrainingset: bool = True,
+ engine: Engine = Engine.PYTORCH,
+):
"""
Loads the demo data -- subset from trail-tracking data in Mathis et al. 2018.
When loading, it sets paths correctly to run this project on your system
@@ -29,6 +34,9 @@ def load_demo_data(config, createtrainingset=True):
createtrainingset : bool
Boolean variable indicating if a training set shall be created.
+ engine: Engine
+ The Engine to create the training set for if a training set shall be created.
+
Example
--------
>>> deeplabcut.load_demo_data('config.yaml')
@@ -40,7 +48,7 @@ def load_demo_data(config, createtrainingset=True):
transform_data(config)
if createtrainingset:
print("Loaded, now creating training data...")
- deeplabcut.create_training_dataset(config, num_shuffles=1)
+ deeplabcut.create_training_dataset(config, num_shuffles=1, engine=engine)
def transform_data(config):
diff --git a/deeplabcut/create_project/modelzoo.py b/deeplabcut/create_project/modelzoo.py
index 76679f1867..7b853f9542 100644
--- a/deeplabcut/create_project/modelzoo.py
+++ b/deeplabcut/create_project/modelzoo.py
@@ -13,13 +13,32 @@
from pathlib import Path
import yaml
-
-import deeplabcut
-from deeplabcut.utils import auxiliaryfunctions
+from dlclibrary import get_available_detectors
from dlclibrary.dlcmodelzoo.modelzoo_download import (
download_huggingface_model,
MODELOPTIONS,
+ get_available_datasets,
+ get_available_models,
+)
+
+import deeplabcut
+from deeplabcut import Engine
+from deeplabcut.core.config import read_config_as_dict, write_config
+from deeplabcut.generate_training_dataset.metadata import (
+ TrainingDatasetMetadata,
+ ShuffleMetadata,
+ DataSplit,
+)
+from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
+ MakeInference_yaml,
+)
+from deeplabcut.modelzoo.utils import get_super_animal_project_cfg
+from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
+ add_metadata,
+ make_pytorch_test_config,
)
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import load_super_animal_config
+from deeplabcut.utils import auxiliaryfunctions
Modeloptions = MODELOPTIONS # backwards compatibility for COLAB NOTEBOOK
@@ -96,17 +115,22 @@ def create_pretrained_human_project(
def create_pretrained_project(
- project,
- experimenter,
- videos,
- model="full_human",
- working_directory=None,
- copy_videos=False,
- videotype="",
- analyzevideo=True,
- filtered=True,
- createlabeledvideo=True,
- trainFraction=None,
+ project: str,
+ experimenter: str,
+ videos: list[str],
+ model: str | None = None,
+ working_directory: str | None = None,
+ copy_videos: bool = False,
+ videotype: str = "",
+ analyzevideo: bool = True,
+ filtered: bool = True,
+ createlabeledvideo: bool = True,
+ trainFraction: float | None = None,
+ engine: Engine = Engine.PYTORCH,
+ multi_animal: bool = False,
+ individuals: list[str] | None = None,
+ net_name: str | None = None,
+ detector_name: str | None = None,
):
"""
Creates a new project directory, sub-directories and a basic configuration file.
@@ -124,43 +148,408 @@ def create_pretrained_project(
experimenter : string
String containing the name of the experimenter.
- model: string, options see http://www.mousemotorlab.org/dlc-modelzoo
- Current option and default: 'full_human' Creates a demo human project and analyzes a video with ResNet 101 weights pretrained on MPII Human Pose. This is from the DeeperCut paper
- by Insafutdinov et al. https://arxiv.org/abs/1605.03170 Please make sure to cite it too if you use this code!
+ model: string | None, default = None,
+ The model / dataset to use as basis for the project.
+ If None, the default model / dataset for the selected engine will be used.
- videos : list
+ videos : list[string]
A list of string containing the full paths of the videos to include in the project.
- working_directory : string, optional
- The directory where the project will be created. The default is the ``current working directory``; if provided, it must be a string.
+ working_directory : string, optional, default = None
+ The directory where the project will be created. If None - the current working directory will be used.
- copy_videos : bool, optional ON WINDOWS: TRUE is often necessary!
- If this is set to True, the videos are copied to the ``videos`` directory. If it is False,symlink of the videos are copied to the project/videos directory. The default is ``False``; if provided it must be either
- ``True`` or ``False``.
+ copy_videos : bool, optional, default = False,
+ If this is set to True, the videos are copied to the ``videos`` directory.
+ If it is False, symlink of the videos are copied to the project/videos directory.
+ Note: on Windows: True is often necessary!
- analyzevideo " bool, optional
- If true, then the video is analyzed and a labeled video is created. If false, then only the project will be created and the weights downloaded. You can then access them
+ analyzevideo: bool, optional
+ If true, then the video is analyzed and a labeled video is created.
+ If false, then only the project will be created and the weights downloaded.
- filtered: bool, default false
- Boolean variable indicating if filtered pose data output should be plotted rather than frame-by-frame predictions.
- Filtered version can be calculated with deeplabcut.filterpredictions
+ filtered: bool, default True
+ Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions.
+ Filtered version can be calculated with deeplabcut.filterpredictions()
- trainFraction: By default value from *new* projects. (0.95)
+ createlabeledvideo: bool, default True,
+ Specifies if a labeled video needs to be created.
+
+ trainFraction: float|None, default = None.
Fraction that will be used in dlc-model/trainingset folder name.
+ If None - default value (0.95) from new projects will be used.
+
+ engine: Engine, default Engine.PYTORCH,
+ engine on which the pretrained weights are based
+
+ multi_animal: bool = False,
+ Specifies if the project is single or multi-animal.
+ Implemented only for Pytorch-based models.
+
+ individuals: list[str] | None = None,
+ Only if multianimal is True.
+ Defines the names of the individuals.
+
+ net_name: str | None, default = None,
+ Valid only if using Pytorch engine.
+ Name of the pose model on which the superanimal dataset has been trained on.
+ If None - "hrnet_w32" will be used as default.
+
+ detector_name: str | None, default = None,
+ Valid only if using Pytorch engine.
+ Name of the detector model on which the superanimal dataset has been trained on.
+ If None - "fasterrcnn_resnet50_fpn_v2" will be used as default.
Example
--------
Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi
- >>> deeplabcut.create_pretrained_project('humanstrokestudy','Linus',['/data/videos/homosapiens1.avi'], copy_videos=False)
+ >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False)
Loading full_cat model and analyzing video "felixfeliscatus3.avi"
- >>> deeplabcut.create_pretrained_project('humanstrokestudy','Linus',['/data/videos/felixfeliscatus3.avi'], model='full_cat')
+ >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF)
Windows:
- >>> deeplabcut.create_pretrained_project('humanstrokestudy','Bill',[r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'],r'C:\yourusername\analysis\project' copy_videos=True)
+ >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True)
Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ )
+ """
+ if engine == Engine.TF:
+ return create_pretrained_project_tensorflow(
+ project=project,
+ experimenter=experimenter,
+ videos=videos,
+ model=model,
+ working_directory=working_directory,
+ copy_videos=copy_videos,
+ videotype=videotype,
+ analyzevideo=analyzevideo,
+ filtered=filtered,
+ createlabeledvideo=createlabeledvideo,
+ trainFraction=trainFraction,
+ )
+ elif engine == Engine.PYTORCH:
+ return create_pretrained_project_pytorch(
+ project=project,
+ experimenter=experimenter,
+ videos=videos,
+ dataset=model,
+ working_directory=working_directory,
+ copy_videos=copy_videos,
+ video_type=videotype,
+ analyze_video=analyzevideo,
+ filtered=filtered,
+ create_labeled_video=createlabeledvideo,
+ train_fraction=trainFraction,
+ multi_animal=multi_animal,
+ individuals=individuals,
+ net_name=net_name,
+ detector_name=detector_name,
+ )
+
+ raise NotImplementedError(f"This function is not implemented for {engine}")
+
+
+def create_pretrained_project_pytorch(
+ project: str,
+ experimenter: str,
+ videos: list[str],
+ dataset: str | None = None,
+ working_directory: str | None = None,
+ copy_videos: bool = False,
+ video_type: str | None = None,
+ analyze_video: bool = True,
+ filtered: bool = True,
+ create_labeled_video: bool = True,
+ train_fraction: float | None = None,
+ multi_animal: bool = False,
+ individuals: list[str] | None = None,
+ net_name: str | None = None,
+ detector_name: str | None = None,
+):
+ """
+ Method used specifically for Pytorch-based ModelZoo models.
+
+ Creates a new project directory, sub-directories and a basic configuration file.
+ Change its parameters to your projects need.
+
+ The project will also be initialized with a pre-trained model from the DeepLabCut model zoo!
+
+ http://modelzoo.deeplabcut.org
+
+ Parameters
+ ----------
+ project : string
+ String containing the name of the project.
+
+ experimenter : string
+ String containing the name of the experimenter.
+
+ dataset: string|None, default = None,
+ The superanimal dataset to use as basis for the project.
+ If not specified - superanimal_quadruped will be used by default.
+
+ videos : list[string]
+ A list of string containing the full paths of the videos to include in the project.
+
+ working_directory : string, optional, default = None
+ The directory where the project will be created. If None - the current working directory will be used.
+
+ copy_videos : bool, optional, default = False,
+ If this is set to True, the videos are copied to the ``videos`` directory.
+ If it is False, symlink of the videos are copied to the project/videos directory.
+ Note: on Windows: True is often necessary!
+
+ analyze_video: bool, optional
+ If true, then the video is analyzed and a labeled video is created.
+ If false, then only the project will be created and the weights downloaded.
+
+ filtered: bool, default True
+ Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions.
+ Filtered version can be calculated with deeplabcut.filterpredictions()
+
+ create_labeled_video: bool, default True
+ Specifies if a labeled video needs to be created.
+
+ train_fraction: float|None, default = None.
+ Fraction that will be used in dlc-model/trainingset folder name.
+ If None - default value (0.95) from new projects will be used.
+
+ multi_animal: bool = False,
+ Specifies if the project is single or multi-animal
+
+ individuals: list[str]|None = None,
+ Only if multianimal is True.
+ Defines the names of the individuals.
+
+ net_name: str | None, default = None,
+ Valid only if using Pytorch engine.
+ Name of the pose model on which the superanimal dataset has been trained on.
+ If None - "hrnet_w32" will be used as default.
+ detector_name: str | None, default = None,
+ Valid only if using Pytorch engine.
+ Name of the detector model on which the superanimal dataset has been trained on.
+ If None - "fasterrcnn_resnet50_fpn_v2" will be used as default.
+
+ Example
+ --------
+ Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi
+ >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False)
+
+ Loading full_cat model and analyzing video "felixfeliscatus3.avi"
+ >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF)
+
+ Windows:
+ >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True)
+ Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ )
"""
+ # Check arguments
+ if not dataset:
+ dataset = "superanimal_quadruped"
+
+ if not net_name:
+ net_name = "hrnet_w32"
+
+ # Currently, all Pytorch Superanimal models are Top-Down.
+ if not detector_name:
+ detector_name = "fasterrcnn_resnet50_fpn_v2"
+
+ if dataset not in get_available_datasets():
+ raise ValueError(
+ f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}"
+ )
+
+ if net_name not in get_available_models(dataset):
+ raise ValueError(
+ f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}"
+ )
+
+ if detector_name not in get_available_detectors(dataset):
+ raise ValueError(
+ f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}"
+ )
+
+ # Create project
+ cfg_path = deeplabcut.create_new_project(
+ project=project,
+ experimenter=experimenter,
+ videos=videos,
+ working_directory=working_directory,
+ copy_videos=copy_videos,
+ videotype=video_type,
+ multianimal=multi_animal,
+ individuals=individuals,
+ )
+
+ # Edits to do to the project config
+ cfg_edits = {}
+ if train_fraction is not None:
+ cfg_edits["TrainingFraction"] = [train_fraction]
+ super_animal_project_cfg = get_super_animal_project_cfg(dataset)
+ super_animal_bodyparts = super_animal_project_cfg.get("bodyparts")
+ super_animal_skeleton = super_animal_project_cfg.get("skeleton")
+ cfg_edits["skeleton"] = super_animal_skeleton
+ if multi_animal:
+ cfg_edits["multianimalbodyparts"] = super_animal_bodyparts
+ else:
+ cfg_edits["bodyparts"] = super_animal_bodyparts
+ auxiliaryfunctions.edit_config(cfg_path, edits=cfg_edits)
+
+ # Create the shuffle train and test directories
+ config = read_config_as_dict(cfg_path)
+ shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder(
+ trainFraction=config["TrainingFraction"][0],
+ shuffle=1,
+ cfg=config,
+ engine=Engine.PYTORCH,
+ )
+ train_dir = shuffle_dir / "train"
+ test_dir = shuffle_dir / "test"
+ train_dir.mkdir(parents=True, exist_ok=True)
+ test_dir.mkdir(parents=True, exist_ok=True)
+
+ # Download the weights and put them into appropriate directory
+ print("Downloading weights...")
+ super_animal_detector_name = f"{dataset}_{detector_name}"
+ new_detector_name = "snapshot-detector-000.pt"
+ download_huggingface_model(
+ model_name=super_animal_detector_name,
+ target_dir=str(train_dir),
+ rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name},
+ )
+ super_animal_model_name = f"{dataset}_{net_name}"
+ new_snapshot_name = "snapshot-000.pt"
+ download_huggingface_model(
+ model_name=super_animal_model_name,
+ target_dir=str(train_dir),
+ rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name},
+ )
+
+ # Create pytorch_config.yaml
+ train_cfg_path = train_dir / "pytorch_config.yaml"
+ pytorch_config = load_super_animal_config(
+ super_animal=dataset,
+ model_name=net_name,
+ detector_name=detector_name,
+ )
+ pytorch_config = add_metadata(config, pytorch_config, train_cfg_path)
+ pytorch_config["resume_training_from"] = str(train_dir / new_snapshot_name)
+ pytorch_config["detector"]["resume_training_from"] = str(
+ train_dir / new_detector_name
+ )
+ write_config(train_cfg_path, pytorch_config)
+
+ # Create test pose_cfg.yaml
+ test_cfg_path = test_dir / "pose_cfg.yaml"
+ make_pytorch_test_config(
+ model_config=pytorch_config, test_config_path=test_cfg_path, save=True
+ )
+
+ # Create inference_cfg.yaml if needed
+ if multi_animal:
+ inference_cfg_path = test_dir / "inference_cfg.yaml"
+ _create_inference_config(inference_cfg_path, config)
+
+ # Create metadata.yaml with shuffle info in training-data directory
+ _create_training_datasets_metadata(config, shuffle_dir.name, Engine.PYTORCH)
+
+ # Process the videos
+ _process_videos(
+ cfg_path=cfg_path,
+ video_type=video_type,
+ analyze_video=analyze_video,
+ filtered=filtered,
+ create_labeled_video=create_labeled_video,
+ )
+ return cfg_path, str(train_cfg_path)
+
+
+def _create_inference_config(inference_cfg_path: str | Path, project_cfg: dict):
+ inf_updates = dict(
+ minimalnumberofconnections=int(len(project_cfg["multianimalbodyparts"]) / 2),
+ topktoretain=len(project_cfg["individuals"]),
+ withid=project_cfg.get("identity", False),
+ )
+ default_inf_path = (
+ Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
+ )
+ MakeInference_yaml(inf_updates, inference_cfg_path, default_inf_path)
+
+
+def create_pretrained_project_tensorflow(
+ project: str,
+ experimenter: str,
+ videos: list[str],
+ model: str | None = None,
+ working_directory: str | None = None,
+ copy_videos: bool = False,
+ videotype: str = "",
+ analyzevideo: bool = True,
+ filtered: bool = True,
+ createlabeledvideo: bool = True,
+ trainFraction: float | None = None,
+):
+ """
+ Method used specifically for Tensorflow-based ModelZoo models.
+
+ Creates a new project directory, sub-directories and a basic configuration file.
+ Change its parameters to your projects need.
+
+ The project will also be initialized with a pre-trained model from the DeepLabCut model zoo!
+
+ http://modelzoo.deeplabcut.org
+
+ Parameters
+ ----------
+ project : string
+ String containing the name of the project.
+
+ experimenter : string
+ String containing the name of the experimenter.
+
+ model: string|None, default = None,
+ The model / dataset to use as basis for the project.
+ If not specified - full_human will be used by default.
+
+ videos : list[string]
+ A list of string containing the full paths of the videos to include in the project.
+
+ working_directory : string, optional, default = None
+ The directory where the project will be created. If None - the current working directory will be used.
+
+ copy_videos : bool, optional, default = False,
+ If this is set to True, the videos are copied to the ``videos`` directory.
+ If it is False, symlink of the videos are copied to the project/videos directory.
+ Note: on Windows: True is often necessary!
+
+ analyzevideo: bool, optional
+ If true, then the video is analyzed and a labeled video is created.
+ If false, then only the project will be created and the weights downloaded.
+
+ filtered: bool, default True
+ Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions.
+ Filtered version can be calculated with deeplabcut.filterpredictions()
+
+ createlabeledvideo: bool, default True
+ Specifies if a labeled video needs to be created.
+
+ trainFraction: float|None, default = None.
+ Fraction that will be used in dlc-model/trainingset folder name.
+ If None - default value (0.95) from new projects will be used.
+
+ Example
+ --------
+ Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi
+ >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False)
+
+ Loading full_cat model and analyzing video "felixfeliscatus3.avi"
+ >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF)
+
+ Windows:
+ >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True)
+ Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ )
+ """
+ if not model:
+ model = "full_human"
+
if model in MODELOPTIONS:
cwd = os.getcwd()
@@ -300,23 +689,71 @@ def create_pretrained_project(
MakeTest_pose_yaml(pose_cfg, keys2save, path_test_config)
- video_dir = os.path.join(config["project_path"], "videos")
- if analyzevideo == True:
- print("Analyzing video...")
- deeplabcut.analyze_videos(cfg, [video_dir], videotype, save_as_csv=True)
-
- if createlabeledvideo == True:
- if filtered:
- deeplabcut.filterpredictions(cfg, [video_dir], videotype)
+ _create_training_datasets_metadata(config, modelfoldername.name, Engine.TF)
- print("Plotting results...")
- deeplabcut.create_labeled_video(
- cfg, [video_dir], videotype, draw_skeleton=True, filtered=filtered
- )
- deeplabcut.plot_trajectories(cfg, [video_dir], videotype, filtered=filtered)
+ _process_videos(
+ cfg_path=cfg,
+ video_type=videotype,
+ analyze_video=analyzevideo,
+ filtered=filtered,
+ create_labeled_video=createlabeledvideo,
+ )
os.chdir(cwd)
return cfg, path_train_config
else:
return "N/A", "N/A"
+
+
+def _create_training_datasets_metadata(
+ config: dict, shuffle_dir_name: str, engine: Engine
+):
+ # First create the metadata object
+ metadata = TrainingDatasetMetadata.create(config)
+
+ # Create a new shuffle with TensorFlow engine
+ new_shuffle = ShuffleMetadata(
+ name=shuffle_dir_name,
+ train_fraction=config["TrainingFraction"][0],
+ index=1,
+ engine=engine,
+ split=DataSplit(train_indices=(), test_indices=()),
+ )
+
+ # Add the shuffle to metadata
+ metadata = metadata.add(new_shuffle)
+
+ # Save the metadata
+ metadata.save()
+
+ return metadata
+
+
+def _process_videos(
+ cfg_path: str | Path,
+ video_type: str = "",
+ analyze_video: bool = True,
+ filtered: bool = True,
+ create_labeled_video: bool = True,
+):
+ cfg_path = str(cfg_path)
+ video_dir = Path(cfg_path).parent / "videos"
+
+ if analyze_video:
+ print("Analyzing video...")
+ deeplabcut.analyze_videos(
+ cfg_path, [video_dir], videotype=video_type, save_as_csv=True
+ )
+
+ if create_labeled_video:
+ if filtered:
+ deeplabcut.filterpredictions(cfg_path, [video_dir], video_type)
+
+ print("Plotting results...")
+ deeplabcut.create_labeled_video(
+ cfg_path, [video_dir], video_type, draw_skeleton=True, filtered=filtered
+ )
+ deeplabcut.plot_trajectories(
+ cfg_path, [video_dir], video_type, filtered=filtered
+ )
diff --git a/deeplabcut/create_project/new.py b/deeplabcut/create_project/new.py
index 0fc7a06139..f812d89bb9 100644
--- a/deeplabcut/create_project/new.py
+++ b/deeplabcut/create_project/new.py
@@ -19,13 +19,14 @@
def create_new_project(
- project,
- experimenter,
- videos,
- working_directory=None,
- copy_videos=False,
- videotype="",
- multianimal=False,
+ project: str,
+ experimenter: str,
+ videos: list[str],
+ working_directory: str | None = None,
+ copy_videos: bool = False,
+ videotype: str = "",
+ multianimal: bool = False,
+ individuals: list[str] | None = None,
):
r"""Create the necessary folders and files for a new project.
@@ -58,6 +59,11 @@ def create_new_project(
multianimal: bool, optional. Default: False.
For creating a multi-animal project (introduced in DLC 2.2)
+ individuals: list[str]|None = None,
+ Relevant only if multianimal is True.
+ list of individuals to be used in the project configuration.
+ If None - defaults to ['individual1', 'individual2', 'individual3']
+
Returns
-------
str
@@ -143,7 +149,9 @@ def create_new_project(
# Check if it is a folder
if os.path.isdir(i):
vids_in_dir = [
- os.path.join(i, vp) for vp in os.listdir(i) if vp.endswith(videotype)
+ os.path.join(i, vp)
+ for vp in os.listdir(i)
+ if vp.lower().endswith(videotype)
]
vids = vids + vids_in_dir
if len(vids_in_dir) == 0:
@@ -239,7 +247,11 @@ def create_new_project(
cfg_file, ruamelFile = auxiliaryfunctions.create_config_template(multianimal)
cfg_file["multianimalproject"] = multianimal
cfg_file["identity"] = False
- cfg_file["individuals"] = ["individual1", "individual2", "individual3"]
+ cfg_file["individuals"] = (
+ individuals
+ if individuals
+ else ["individual1", "individual2", "individual3"]
+ )
cfg_file["multianimalbodyparts"] = ["bodypart1", "bodypart2", "bodypart3"]
cfg_file["uniquebodyparts"] = []
cfg_file["bodyparts"] = "MULTI!"
@@ -272,6 +284,7 @@ def create_new_project(
cfg_file["TrainingFraction"] = [0.95]
cfg_file["iteration"] = 0
cfg_file["snapshotindex"] = -1
+ cfg_file["detector_snapshotindex"] = -1
cfg_file["x1"] = 0
cfg_file["x2"] = 640
cfg_file["y1"] = 277
@@ -279,6 +292,7 @@ def create_new_project(
cfg_file["batch_size"] = (
8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242
)
+ cfg_file["detector_batch_size"] = 1
cfg_file["corner2move2"] = (50, 50)
cfg_file["move2corner"] = True
cfg_file["skeleton_color"] = "black"
diff --git a/deeplabcut/generate_training_dataset/__init__.py b/deeplabcut/generate_training_dataset/__init__.py
index 05b0092d49..60eac17c1d 100644
--- a/deeplabcut/generate_training_dataset/__init__.py
+++ b/deeplabcut/generate_training_dataset/__init__.py
@@ -13,3 +13,8 @@
from deeplabcut.generate_training_dataset.frame_extraction import *
from deeplabcut.generate_training_dataset.trainingsetmanipulation import *
from deeplabcut.generate_training_dataset.multiple_individuals_trainingsetmanipulation import *
+from deeplabcut.generate_training_dataset.metadata import (
+ DataSplit,
+ ShuffleMetadata,
+ TrainingDatasetMetadata,
+)
diff --git a/deeplabcut/generate_training_dataset/metadata.py b/deeplabcut/generate_training_dataset/metadata.py
new file mode 100644
index 0000000000..a26a5eadda
--- /dev/null
+++ b/deeplabcut/generate_training_dataset/metadata.py
@@ -0,0 +1,481 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""File containing methods to load and parse shuffle metadata"""
+from __future__ import annotations
+
+import logging
+import pickle
+import re
+from dataclasses import dataclass
+from pathlib import Path
+
+import numpy as np
+from ruamel.yaml import YAML
+
+from deeplabcut.core.engine import Engine
+from deeplabcut.utils import auxiliaryfunctions
+
+
+@dataclass(frozen=True)
+class DataSplit:
+ """Class representing the metadata for a shuffle"""
+ train_indices: tuple[int, ...]
+ test_indices: tuple[int, ...]
+
+ def __post_init__(self) -> None:
+ """
+ Raises:
+ ValueError if the indices are not sorted in increasing
+ """
+ for indices in [self.train_indices, self.test_indices]:
+ idx = np.array(indices)
+ if not np.all(idx[:-1] < idx[1:]):
+ raise RuntimeError(
+ f"The training and test indices in a data split must be sorted in "
+ f"strictly ascending order."
+ )
+
+
+@dataclass(frozen=True)
+class ShuffleMetadata:
+ """Class representing the metadata for a shuffle"""
+ name: str
+ train_fraction: float
+ index: int
+ engine: Engine
+ split: DataSplit | None
+
+ def load_split(self, cfg: dict, trainset_path: Path) -> "ShuffleMetadata":
+ """Loads the data split for this shuffle
+
+ Args:
+ cfg: the config for the DeepLabCut project
+ trainset_path: the path to the training dataset folder
+
+ Returns:
+ a new instance with the data split defined
+ """
+ _, doc_path = auxiliaryfunctions.get_data_and_metadata_filenames(
+ trainset_path, self.train_fraction, self.index, cfg
+ )
+ if not Path(doc_path).exists():
+ raise ValueError(
+ f"Could not load the metadata file for {self} as {doc_path} does not "
+ f"exist. If you deleted the shuffle, you also need to delete the "
+ f"shuffle from metadata.yaml or recreate the metadata.yaml file."
+ )
+
+ with open(doc_path, "rb") as f:
+ _, train_idx, test_idx, _ = pickle.load(f)
+ return ShuffleMetadata(
+ name=self.name,
+ train_fraction=self.train_fraction,
+ index=self.index,
+ engine=self.engine,
+ split=DataSplit(
+ train_indices=tuple(sorted([int(idx) for idx in train_idx])),
+ test_indices=tuple(sorted([int(idx) for idx in test_idx])),
+ )
+ )
+
+
+@dataclass(frozen=True)
+class TrainingDatasetMetadata:
+ """An immutable class containing the metadata for a dataset
+
+ When creating a new "training-datasets" folder (e.g., when creating the first
+ training set for a project, or when creating the first training for a given
+ iteration of a project), TrainingDatasetMetadata.create(cfg) should be called when
+ the "training-datasets" folder is still empty.
+
+ For existing projects (created with DeepLabCut < 3.0), calling
+ TrainingDatasetMetadata.create(cfg) will go over documentation data for all existing
+ shuffles in the training-datasets folder and add them to a new metadata instance.
+ All shuffles will be given Engine.TF as an engine.
+
+ Examples:
+ # Creating the metadata file for an existing project
+ config = "/data/my-dlc-project/config.yaml"
+ trainset_metadata = TrainingDatasetMetadata.create(config)
+ trainset_metadata.save()
+
+ # Adding a new shuffle to the metadata file
+ config = "/data/my-dlc-project-2008-06-17/config.yaml"
+ trainset_metadata = TrainingDatasetMetadata.load(config)
+ new_shuffle = ShuffleMetadata(
+ name="my-dlc-projectJun17-trainset60shuffle5",
+ train_fraction=0.6,
+ index=5,
+ engine=compat.Engine.PYTORCH,
+ split=DataSplit(train_indices=(1, 3, 4), test_indices=(0, 2)),
+ )
+ trainset_metadata = trainset_metadata.add(new_shuffle)
+ trainset_metadata.save() # saves to disk
+ """
+ project_config: dict
+ shuffles: tuple[ShuffleMetadata, ...]
+ file_header: tuple[str] = (
+ "# This file is automatically generated - DO NOT EDIT",
+ "# It contains the information about the shuffles created for the dataset",
+ "---",
+ )
+
+ def __post_init__(self) -> None:
+ """
+ Raises:
+ ValueError if the indices are not sorted in increasing order
+ """
+ indices = [[s.train_fraction, s.index] for s in self.shuffles]
+ for (frac1, idx1), (frac2, idx2) in zip(indices[:-1], indices[1:]):
+ if not (frac1 < frac2 or (frac1 == frac2 and idx1 < idx2)):
+ raise RuntimeError(
+ "The shuffles given must be sorted in order of ascending training "
+ f"fraction and index. Found {self.shuffles}"
+ )
+
+ def add(
+ self,
+ shuffle: ShuffleMetadata,
+ overwrite: bool = False,
+ ) -> TrainingDatasetMetadata:
+ """
+ Adds a new shuffle to the metadata file
+
+ Args:
+ shuffle: the shuffle to add
+ overwrite: if a shuffle with the same index is already stored in the
+ metadata file, whether to overwrite it
+
+ Returns:
+ A new instance of TrainingDatasetMetadata with updated shuffles
+
+ Raises:
+ ValueError: if overwrite=False and there is already a shuffle with the given
+ index in the metadata file.
+ """
+ existing_indices = [
+ s.index for s in self.shuffles if s.train_fraction == shuffle.train_fraction
+ ]
+ if shuffle.index in existing_indices:
+ if not overwrite:
+ raise RuntimeError(
+ f"Cannot add {shuffle} to the meta: a shuffle with index "
+ f"{shuffle.index} and train_fraction {shuffle.train_fraction} "
+ f"already exists: {self.shuffles}."
+ )
+
+ existing_shuffles = [
+ s
+ for s in self.shuffles
+ if (s.index != shuffle.index or s.train_fraction != shuffle.train_fraction)
+ ]
+ shuffles = existing_shuffles + [shuffle]
+ return TrainingDatasetMetadata(
+ project_config=self.project_config,
+ shuffles=tuple(sorted(shuffles, key=lambda s: (s.train_fraction, s.index))),
+ )
+
+ def get(self, trainset_index: int = 0, index: int = 0) -> ShuffleMetadata:
+ """
+ Args:
+ trainset_index: the index of the trainset fraction as defined in config.yaml
+ index: the index of the shuffle
+
+ Returns:
+ the shuffle with the given trainset index and shuffle index
+
+ Raises:
+ ValueError if the shuffle is not present in the metadata
+ """
+ train_fraction = self.project_config["TrainingFraction"][trainset_index]
+ for shuffle in self.shuffles:
+ if (
+ shuffle.train_fraction == train_fraction
+ and shuffle.index == index
+ ):
+ return shuffle
+
+ raise ValueError(
+ f"Could not find a shuffle with trainingset fraction {train_fraction} and "
+ f"index {index}"
+ )
+
+ def save(self) -> None:
+ """Saves the training dataset metadata to disk"""
+ metadata = {"shuffles": {}}
+ data_splits: dict[DataSplit, int] = {}
+ trainset_path = self.path(self.project_config).parent
+ for s in self.shuffles:
+ if s.split is None:
+ s = s.load_split(cfg=self.project_config, trainset_path=trainset_path)
+
+ split_index = data_splits.get(s.split)
+ if split_index is None:
+ split_index = len(data_splits) + 1
+ data_splits[s.split] = split_index
+
+ metadata["shuffles"][s.name] = {
+ "train_fraction": s.train_fraction,
+ "index": s.index,
+ "split": split_index,
+ "engine": s.engine.aliases[0],
+ }
+
+ with open(self.path(self.project_config), "w") as file:
+ file.write("\n".join(self.file_header) + "\n")
+ YAML().dump(metadata, file)
+
+ @staticmethod
+ def load(
+ config: str | Path | dict,
+ load_splits: bool = False,
+ ) -> TrainingDatasetMetadata:
+ """Loads the metadata from disk
+
+ Args:
+ config: the config for the DeepLabCut project (or its path)
+ load_splits: whether to load the data split for each shuffle
+ """
+ if isinstance(config, (str, Path)):
+ cfg = auxiliaryfunctions.read_config(config)
+ else:
+ cfg = config
+
+ metadata_path = TrainingDatasetMetadata.path(cfg)
+ with open(metadata_path, "r") as file:
+ metadata = YAML(typ="safe", pure=True).load(file)
+
+ shuffles = []
+ for shuffle_name, shuffle_metadata in metadata["shuffles"].items():
+ shuffle = ShuffleMetadata(
+ name=shuffle_name,
+ train_fraction=shuffle_metadata["train_fraction"],
+ index=shuffle_metadata["index"],
+ engine=Engine(shuffle_metadata["engine"]),
+ split=None,
+ )
+ if load_splits:
+ shuffle = shuffle.load_split(cfg, metadata_path.parent)
+
+ shuffles.append(shuffle)
+
+ shuffles.sort(key=lambda s: (s.train_fraction, s.index))
+ return TrainingDatasetMetadata(project_config=cfg, shuffles=tuple(shuffles))
+
+ @staticmethod
+ def create(config: str | Path | dict) -> TrainingDatasetMetadata:
+ """Function to create the metadata file
+
+ Assumes that all existing shuffles use the TensorFlow engine, as this file
+ should have already been created for PyTorch shuffles.
+
+ Args;
+ config: the config for the DeepLabCut project (or its path)
+ default_engine: the default engine to set for shuffles in the project
+
+ Returns:
+ the metadata for the existing shuffles in the project
+ """
+ if isinstance(config, (str, Path)):
+ cfg = auxiliaryfunctions.read_config(config)
+ else:
+ cfg = config
+
+ trainset_path = TrainingDatasetMetadata.path(cfg).parent
+ if trainset_path.exists():
+ shuffle_docs = [
+ f
+ for f in trainset_path.iterdir()
+ if re.match(r"Documentation_data-.+shuffle[0-9]+\.pickle", f.name)
+ ]
+ else:
+ trainset_path.mkdir(parents=True)
+ shuffle_docs = []
+
+ prefix = cfg["Task"] + cfg["date"]
+ shuffles = []
+ existing_splits: dict[tuple[tuple[int, ...], tuple[int, ...]], int] = {}
+ for doc_path in shuffle_docs:
+ index = int(doc_path.stem.split("shuffle")[-1])
+ with open(doc_path, "rb") as f:
+ _, train_idx, test_idx, train_frac = pickle.load(f)
+
+ engine = Engine.TF
+ train_idx = tuple(sorted([int(idx) for idx in train_idx]))
+ test_idx = tuple(sorted([int(idx) for idx in test_idx]))
+ split_idx = existing_splits.get((train_idx, test_idx))
+ if split_idx is None:
+ split_idx = len(existing_splits) + 1
+ existing_splits[(train_idx, test_idx)] = split_idx
+
+ shuffles.append(
+ ShuffleMetadata(
+ name=f"{prefix}-trainset{int(100 * train_frac)}shuffle{index}",
+ train_fraction=train_frac,
+ index=index,
+ engine=engine,
+ split=DataSplit(train_indices=train_idx, test_indices=test_idx),
+ )
+ )
+
+ shuffles = tuple(sorted(shuffles, key=lambda s: (s.train_fraction, s.index)))
+ return TrainingDatasetMetadata(
+ project_config=cfg,
+ shuffles=shuffles,
+ )
+
+ @staticmethod
+ def path(cfg: dict) -> Path:
+ """
+ Args:
+ cfg: the config for the DeepLabCut project
+
+ Returns:
+ the path to the training dataset metadata file
+ """
+ meta_path = auxiliaryfunctions.get_training_set_folder(cfg) / "metadata.yaml"
+ return Path(cfg["project_path"]) / meta_path
+
+
+def update_metadata(
+ cfg: dict,
+ train_fraction: float,
+ shuffle: int,
+ engine: Engine,
+ train_indices: list[int],
+ test_indices: list[int],
+ overwrite: bool = False,
+) -> None:
+ """Updates the metadata for a training-dataset
+
+ Args:
+ cfg: the config for the DeepLabCut project
+ train_fraction: the train_fraction of the new shuffle
+ shuffle: the index of the shuffle to add
+ engine: the engine for the shuffle
+ train_indices: the indices of images in the training set
+ test_indices: the indices of images in the test set
+ overwrite: whether to overwrite a shuffle with the same index and train fraction
+ if one exists
+
+ Raises:
+ ValueError: if overwrite=False and there is already a shuffle with the given
+ index in the metadata file.
+ """
+ prefix = cfg["Task"] + cfg["date"]
+ metadata = TrainingDatasetMetadata.load(cfg, load_splits=True)
+ new_shuffle = ShuffleMetadata(
+ name=f"{prefix}-trainset{int(100 * train_fraction)}shuffle{shuffle}",
+ train_fraction=train_fraction,
+ index=shuffle,
+ engine=engine,
+ split=DataSplit(
+ train_indices=tuple(sorted([int(i) for i in train_indices])),
+ test_indices=tuple(sorted([int(i) for i in test_indices])),
+ )
+ )
+ metadata = metadata.add(shuffle=new_shuffle, overwrite=overwrite)
+ metadata.save()
+
+
+def get_shuffle_engine(
+ cfg: dict,
+ trainingsetindex: int,
+ shuffle: int,
+ modelprefix: str = "",
+) -> Engine:
+ """
+ Args:
+ cfg: the config for the DeepLabCut project
+ trainingsetindex: the training set index used
+ shuffle: the shuffle for which to get the engine
+ modelprefix: the model prefix, if there is one
+
+ Returns:
+ the engine that the shuffle was created with
+
+ Raises:
+ ValueError if the engine for the shuffle cannot be determined or the shuffle
+ doesn't exist
+ """
+ if not TrainingDatasetMetadata.path(cfg).exists():
+ metadata = TrainingDatasetMetadata.create(cfg)
+ metadata.save()
+
+ metadata = TrainingDatasetMetadata.load(cfg)
+ shuffle_metadata = metadata.get(trainingsetindex, shuffle)
+ if modelprefix:
+ # try to get the engine by checking which models folder exists
+ engines = find_engines_from_model_folders(
+ cfg, trainingsetindex, shuffle, modelprefix
+ )
+ if len(engines) == 0:
+ raise ValueError(
+ f"Couldn't find any shuffles with trainingsetindex={trainingsetindex}, "
+ f"shuffle={shuffle} and modelprefix={modelprefix}. Please check that "
+ f"such a shuffle is defined."
+ )
+
+ if len(engines) == 1:
+ return engines.pop()
+
+ if shuffle_metadata.engine in engines:
+ engine = shuffle_metadata.engine
+ else:
+ engine = engines.pop() # take a random engine
+
+ logging.warning(
+ f"Found multiple engines for trainingsetindex={trainingsetindex}, "
+ f"shuffle={shuffle} and modelprefix={modelprefix}. Using engine={engine}. "
+ f"To select another engine, please specify it in your API call."
+ )
+ return engine
+
+ return shuffle_metadata.engine
+
+
+def find_engines_from_model_folders(
+ cfg: dict,
+ trainingsetindex: int,
+ shuffle: int,
+ modelprefix: str = "",
+) -> set[Engine]:
+ """Determines which engines are used with a given shuffle.
+
+ This method can be useful when using modelprefix, as the engine for a shuffle stored
+ under a "modelprefix" might not be the same as the base shuffle (for which the
+ engine is stored in the training-datasets folder).
+
+ Args:
+ cfg: the config for the DeepLabCut project
+ trainingsetindex: the training set index used
+ shuffle: the shuffle for which to get the engine
+ modelprefix: the model prefix, if there is one
+
+ Returns:
+ the engines for which a model folder exists for the given shuffle
+ """
+ project_path = Path(cfg["project_path"])
+ train_fraction = cfg["TrainingFraction"][trainingsetindex]
+
+ existing_engines = set()
+ for engine in Engine:
+ expected_model_folder = project_path / auxiliaryfunctions.get_model_folder(
+ trainFraction=train_fraction,
+ shuffle=shuffle,
+ cfg=cfg,
+ engine=engine,
+ modelprefix=modelprefix,
+ )
+ if expected_model_folder.exists():
+ existing_engines.add(engine)
+
+ return existing_engines
diff --git a/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py b/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py
index 10a41c4d6f..1095dfef6f 100755
--- a/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py
+++ b/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py
@@ -8,6 +8,7 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+from __future__ import annotations
import os
import os.path
@@ -19,6 +20,10 @@
import numpy as np
from tqdm import tqdm
+import deeplabcut.compat as compat
+import deeplabcut.generate_training_dataset.metadata as metadata
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.weight_init import WeightInitialization
from deeplabcut.generate_training_dataset import (
merge_annotateddatasets,
read_image_shape_fast,
@@ -27,6 +32,7 @@
MakeTest_pose_yaml,
MakeInference_yaml,
pad_train_test_indices,
+ validate_shuffles,
)
from deeplabcut.utils import (
auxiliaryfunctions,
@@ -101,6 +107,7 @@ def create_multianimaltraining_dataset(
Shuffles=None,
windows2linux=False,
net_type=None,
+ detector_type=None,
numdigits=2,
crop_size=(400, 400),
crop_sampling="hybrid",
@@ -109,15 +116,20 @@ def create_multianimaltraining_dataset(
testIndices=None,
n_edges_threshold=105,
paf_graph_degree=6,
+ userfeedback: bool = True,
+ weight_init: WeightInitialization | None = None,
+ engine: Engine | None = None,
):
"""
- Creates a training dataset for multi-animal datasets. Labels from all the extracted frames are merged into a single .h5 file.\n
+ Creates a training dataset for multi-animal datasets. Labels from all the extracted
+ frames are merged into a single .h5 file.\n
Only the videos included in the config file are used to create this dataset.\n
- [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more videos to the project.
+ [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more
+ videos to the project.
Important differences to standard:
- stores coordinates with numdigits as many digits
- - creates
+
Parameter
----------
config : string
@@ -130,17 +142,53 @@ def create_multianimaltraining_dataset(
Alternatively the user can also give a list of shuffles (integers!).
net_type: string
- Type of networks. Currently resnet_50, resnet_101, and resnet_152, efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3,
- efficientnet-b4, efficientnet-b5, and efficientnet-b6 as well as dlcrnet_ms5 are supported (not the MobileNets!).
- See Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1
+ Type of networks. The options available depend on which engine is used. See
+ Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1
+ Currently supported options are:
+ TensorFlow
+ * ``resnet_50``
+ * ``resnet_101``
+ * ``resnet_152``
+ * ``efficientnet-b0``
+ * ``efficientnet-b1``
+ * ``efficientnet-b2``
+ * ``efficientnet-b3``
+ * ``efficientnet-b4``
+ * ``efficientnet-b5``
+ * ``efficientnet-b6``
+ PyTorch (call ``deeplabcut.pose_estimation_pytorch.available_models()`` for
+ a complete list)
+ * ``resnet_50``
+ * ``resnet_101``
+ * ``dekr_w18``
+ * ``dekr_w32``
+ * ``dekr_w48``
+ * ``top_down_resnet_50``
+ * ``top_down_resnet_101``
+ * ``top_down_hrnet_w18``
+ * ``top_down_hrnet_w32``
+ * ``top_down_hrnet_w48``
+ * ``animaltokenpose_base``
+
+ detector_type: string, optional, default=None
+ Only for the PyTorch engine.
+ When passing creating shuffles for top-down models, you can specify which
+ detector you want. If the detector_type is None, the ```ssdlite``` will be used.
+ The list of all available detectors can be obtained by calling
+ ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported options:
+ * ``ssdlite``
+ * ``fasterrcnn_mobilenet_v3_large_fpn``
+ * ``fasterrcnn_resnet50_fpn_v2``
numdigits: int, optional
crop_size: tuple of int, optional
+ Only for the TensorFlow engine.
Dimensions (width, height) of the crops for data augmentation.
Default is 400x400.
crop_sampling: str, optional
+ Only for the TensorFlow engine.
Crop centers sampling method. Must be either:
"uniform" (randomly over the image),
"keypoints" (randomly over the annotated keypoints),
@@ -149,6 +197,7 @@ def create_multianimaltraining_dataset(
Default is "hybrid".
paf_graph: list of lists, or "config" optional (default=None)
+ Only for the TensorFlow engine.
If not None, overwrite the default complete graph. This is useful for advanced users who
already know a good graph, or simply want to use a specific one. Note that, in that case,
the data-driven selection procedure upon model evaluation will be skipped.
@@ -163,11 +212,27 @@ def create_multianimaltraining_dataset(
List of one or multiple lists containing test indexes.
n_edges_threshold: int, optional (default=105)
+ Only for the TensorFlow engine.
Number of edges above which the graph is automatically pruned.
paf_graph_degree: int, optional (default=6)
+ Only for the TensorFlow engine.
Degree of paf_graph when automatically pruning it (before training).
+ userfeedback: bool, optional, default=True
+ If ``False``, all requested train/test splits are created (no matter if they
+ already exist). If you want to assure that previous splits etc. are not
+ overwritten, set this to ``True`` and you will be asked for each split.
+
+ weight_init: WeightInitialisation, optional, default=None
+ PyTorch engine only. Specify how model weights should be initialized. The
+ default mode uses transfer learning from ImageNet weights.
+
+ engine: Engine, optional
+ Whether to create a pose config for a Tensorflow or PyTorch model. Defaults to
+ the value specified in the project configuration file. If no engine is specified
+ for the project, defaults to ``deeplabcut.compat.DEFAULT_ENGINE``.
+
Example
--------
>>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1)
@@ -202,6 +267,11 @@ def create_multianimaltraining_dataset(
full_training_path = Path(project_path, trainingsetfolder)
auxiliaryfunctions.attempt_to_make_folder(full_training_path, recursive=True)
+ # Create the trainset metadata file, if it doesn't yet exist
+ if not metadata.TrainingDatasetMetadata.path(cfg).exists():
+ trainset_metadata = metadata.TrainingDatasetMetadata.create(cfg)
+ trainset_metadata.save()
+
Data = merge_annotateddatasets(cfg, full_training_path)
if Data is None:
return
@@ -209,13 +279,21 @@ def create_multianimaltraining_dataset(
if net_type is None: # loading & linking pretrained models
net_type = cfg.get("default_net_type", "dlcrnet_ms5")
- elif not any(net in net_type for net in ("resnet", "eff", "dlc", "mob")):
- raise ValueError(f"Unsupported network {net_type}.")
+
+ # load the engine to use to create the shuffle
+ if engine is None:
+ engine = compat.get_project_engine(cfg)
+
+ if not (
+ any(net in net_type for net in ("resnet", "eff", "dlc", "mob"))
+ or engine == Engine.PYTORCH
+ ):
+ raise ValueError(f"Unsupported network {net_type} for engine {engine}.")
multi_stage = False
### dlcnet_ms5: backbone resnet50 + multi-fusion & multi-stage module
### dlcr101_ms5/dlcr152_ms5: backbone resnet101/152 + multi-fusion & multi-stage module
- if all(net in net_type for net in ("dlcr", "_ms5")):
+ if all(net in net_type for net in ("dlcr", "_ms5")) and engine != Engine.PYTORCH:
num_layers = re.findall("dlcr([0-9]*)", net_type)[0]
if num_layers == "":
num_layers = 50
@@ -272,12 +350,13 @@ def create_multianimaltraining_dataset(
# Loading the encoder (if necessary downloading from TF)
dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
- model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
- if Shuffles is None:
- Shuffles = range(1, num_shuffles + 1, 1)
+ if engine == Engine.PYTORCH:
+ model_path = dlcparent_path
else:
- Shuffles = [i for i in Shuffles if isinstance(i, int)]
+ model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
+
+ Shuffles = validate_shuffles(cfg, Shuffles, num_shuffles, userfeedback)
# print(trainIndices,testIndices, Shuffles, augmenter_type,net_type)
if trainIndices is None and testIndices is None:
@@ -309,6 +388,11 @@ def create_multianimaltraining_dataset(
test_inds = test_inds[test_inds != -1]
splits.append((trainFraction, Shuffles[shuffle], (train_inds, test_inds)))
+ top_down = False
+ if engine == Engine.PYTORCH and net_type.startswith("top_down_"):
+ top_down = True
+ net_type = net_type[len("top_down_") :]
+
for trainFraction, shuffle, (trainIndices, testIndices) in splits:
####################################################
# Generating data structure with labeled information & frame metadata (for deep cut)
@@ -345,6 +429,15 @@ def create_multianimaltraining_dataset(
testIndices,
trainFraction,
)
+ metadata.update_metadata(
+ cfg=cfg,
+ train_fraction=trainFraction,
+ shuffle=shuffle,
+ engine=engine,
+ train_indices=trainIndices,
+ test_indices=testIndices,
+ overwrite=not userfeedback,
+ )
datafilename = datafilename.split(".mat")[0] + ".pickle"
import pickle
@@ -359,7 +452,10 @@ def create_multianimaltraining_dataset(
#################################################################################
modelfoldername = auxiliaryfunctions.get_model_folder(
- trainFraction, shuffle, cfg
+ trainFraction,
+ shuffle,
+ cfg,
+ engine=engine,
)
auxiliaryfunctions.attempt_to_make_folder(
Path(config).parents[0] / modelfoldername, recursive=True
@@ -396,88 +492,126 @@ def create_multianimaltraining_dataset(
)
)
- jointnames = [str(bpt) for bpt in multianimalbodyparts]
- jointnames.extend([str(bpt) for bpt in uniquebodyparts])
- items2change = {
- "dataset": datafilename,
- "metadataset": metadatafilename,
- "num_joints": len(multianimalbodyparts)
- + len(uniquebodyparts), # cfg["uniquebodyparts"]),
- "all_joints": [
- [i] for i in range(len(multianimalbodyparts) + len(uniquebodyparts))
- ], # cfg["uniquebodyparts"]))],
- "all_joints_names": jointnames,
- "init_weights": model_path,
- "project_path": str(cfg["project_path"]),
- "net_type": net_type,
- "multi_stage": multi_stage,
- "pairwise_loss_weight": 0.1,
- "pafwidth": 20,
- "partaffinityfield_graph": partaffinityfield_graph,
- "partaffinityfield_predict": partaffinityfield_predict,
- "weigh_only_present_joints": False,
- "num_limbs": len(partaffinityfield_graph),
- "dataset_type": dataset_type,
- "optimizer": "adam",
- "batch_size": 8,
- "multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]],
- "save_iters": 10000,
- "display_iters": 500,
- "num_idchannel": (
- len(cfg["individuals"]) if cfg.get("identity", False) else 0
- ),
- "crop_size": list(crop_size),
- "crop_sampling": crop_sampling,
- }
+ if engine == Engine.TF:
+ jointnames = [str(bpt) for bpt in multianimalbodyparts]
+ jointnames.extend([str(bpt) for bpt in uniquebodyparts])
+ items2change = {
+ "dataset": datafilename,
+ "engine": engine.aliases[0],
+ "metadataset": metadatafilename,
+ "num_joints": len(multianimalbodyparts)
+ + len(uniquebodyparts), # cfg["uniquebodyparts"]),
+ "all_joints": [
+ [i]
+ for i in range(len(multianimalbodyparts) + len(uniquebodyparts))
+ ], # cfg["uniquebodyparts"]))],
+ "all_joints_names": jointnames,
+ "init_weights": str(model_path),
+ "project_path": str(cfg["project_path"]),
+ "net_type": net_type,
+ "multi_stage": multi_stage,
+ "pairwise_loss_weight": 0.1,
+ "pafwidth": 20,
+ "partaffinityfield_graph": partaffinityfield_graph,
+ "partaffinityfield_predict": partaffinityfield_predict,
+ "weigh_only_present_joints": False,
+ "num_limbs": len(partaffinityfield_graph),
+ "dataset_type": dataset_type,
+ "optimizer": "adam",
+ "batch_size": 8,
+ "multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]],
+ "save_iters": 10000,
+ "display_iters": 500,
+ "num_idchannel": (
+ len(cfg["individuals"]) if cfg.get("identity", False) else 0
+ ),
+ "crop_size": list(crop_size),
+ "crop_sampling": crop_sampling,
+ }
+
+ trainingdata = MakeTrain_pose_yaml(
+ items2change,
+ path_train_config,
+ defaultconfigfile,
+ save=(engine == Engine.TF),
+ )
+ keys2save = [
+ "dataset",
+ "num_joints",
+ "all_joints",
+ "all_joints_names",
+ "net_type",
+ "multi_stage",
+ "init_weights",
+ "global_scale",
+ "location_refinement",
+ "locref_stdev",
+ "dataset_type",
+ "partaffinityfield_predict",
+ "pairwise_predict",
+ "partaffinityfield_graph",
+ "num_limbs",
+ "dataset_type",
+ "num_idchannel",
+ ]
+
+ MakeTest_pose_yaml(
+ trainingdata,
+ keys2save,
+ path_test_config,
+ nmsradius=5.0,
+ minconfidence=0.01,
+ sigma=1,
+ locref_smooth=False,
+ ) # setting important def. values for inference
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
+ make_pytorch_pose_config,
+ make_pytorch_test_config,
+ )
+ from deeplabcut.pose_estimation_pytorch.modelzoo.config import (
+ make_super_animal_finetune_config,
+ )
- trainingdata = MakeTrain_pose_yaml(
- items2change, path_train_config, defaultconfigfile
- )
- keys2save = [
- "dataset",
- "num_joints",
- "all_joints",
- "all_joints_names",
- "net_type",
- "multi_stage",
- "init_weights",
- "global_scale",
- "location_refinement",
- "locref_stdev",
- "dataset_type",
- "partaffinityfield_predict",
- "pairwise_predict",
- "partaffinityfield_graph",
- "num_limbs",
- "dataset_type",
- "num_idchannel",
- ]
+ # backwards compatibility with version 2.X
+ if net_type == "dlcrnet_ms5":
+ net_type = "dlcrnet_stride16_ms5"
+
+ config_path = Path(path_train_config).with_name(engine.pose_cfg_name)
+ if weight_init is not None and weight_init.with_decoder:
+ pytorch_cfg = make_super_animal_finetune_config(
+ project_config=cfg,
+ pose_config_path=config_path,
+ model_name=net_type,
+ detector_name=detector_type,
+ weight_init=weight_init,
+ save=True,
+ )
+ else:
+ pytorch_cfg = make_pytorch_pose_config(
+ project_config=cfg,
+ pose_config_path=config_path,
+ net_type=net_type,
+ top_down=top_down,
+ detector_type=detector_type,
+ weight_init=weight_init,
+ save=True,
+ )
- MakeTest_pose_yaml(
- trainingdata,
- keys2save,
- path_test_config,
- nmsradius=5.0,
- minconfidence=0.01,
- sigma=1,
- locref_smooth=False,
- ) # setting important def. values for inference
+ make_pytorch_test_config(pytorch_cfg, path_test_config, save=True)
# Setting inference cfg file:
- defaultinference_configfile = os.path.join(
- dlcparent_path, "inference_cfg.yaml"
- )
- items2change = {
- "minimalnumberofconnections": int(len(cfg["multianimalbodyparts"]) / 2),
- "topktoretain": len(cfg["individuals"]),
- "withid": cfg.get("identity", False),
- }
- MakeInference_yaml(
- items2change, path_inference_config, defaultinference_configfile
+ default_inf_path = Path(dlcparent_path) / "inference_cfg.yaml"
+ inf_updates = dict(
+ minimalnumberofconnections=int(len(cfg["multianimalbodyparts"]) / 2),
+ topktoretain=len(cfg["individuals"]),
+ withid=cfg.get("identity", False),
)
+ MakeInference_yaml(inf_updates, path_inference_config, default_inf_path)
print(
- "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!"
+ "The training dataset is successfully created. Use the function "
+ "'train_network' to start training. Happy training!"
)
else:
pass
diff --git a/deeplabcut/generate_training_dataset/trainingsetmanipulation.py b/deeplabcut/generate_training_dataset/trainingsetmanipulation.py
index 244ab2fcb5..08f7694633 100755
--- a/deeplabcut/generate_training_dataset/trainingsetmanipulation.py
+++ b/deeplabcut/generate_training_dataset/trainingsetmanipulation.py
@@ -8,6 +8,7 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+from __future__ import annotations
import math
import logging
@@ -24,7 +25,10 @@
import pandas as pd
import yaml
-from deeplabcut.pose_estimation_tensorflow import training
+import deeplabcut.compat as compat
+import deeplabcut.generate_training_dataset.metadata as metadata
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.weight_init import WeightInitialization
from deeplabcut.utils import (
auxiliaryfunctions,
conversioncode,
@@ -32,8 +36,6 @@
auxfun_multianimal,
)
from deeplabcut.utils.auxfun_videos import VideoReader
-from deeplabcut.pose_estimation_tensorflow.config import load_config
-from deeplabcut.modelzoo.utils import parse_available_supermodels
def comparevideolistsanddatafolders(config):
@@ -397,19 +399,26 @@ def ParseYaml(configfile):
def MakeTrain_pose_yaml(
- itemstochange, saveasconfigfile, defaultconfigfile, items2drop={}
+ itemstochange,
+ saveasconfigfile,
+ defaultconfigfile,
+ items2drop: dict | None = None,
+ save: bool = True,
):
+ if items2drop is None:
+ items2drop = {}
+
docs = ParseYaml(defaultconfigfile)
for key in items2drop.keys():
- # print(key, "dropping?")
if key in docs[0].keys():
docs[0].pop(key)
for key in itemstochange.keys():
docs[0][key] = itemstochange[key]
- with open(saveasconfigfile, "w") as f:
- yaml.dump(docs[0], f)
+ if save:
+ with open(saveasconfigfile, "w") as f:
+ yaml.dump(docs[0], f)
return docs[0]
@@ -772,13 +781,16 @@ def create_training_dataset(
num_shuffles=1,
Shuffles=None,
windows2linux=False,
- userfeedback=False,
+ userfeedback=True,
trainIndices=None,
testIndices=None,
net_type=None,
+ detector_type=None,
augmenter_type=None,
posecfg_template=None,
superanimal_name="",
+ weight_init: WeightInitialization | None = None,
+ engine: Engine | None = None,
):
"""Creates a training dataset.
@@ -797,7 +809,7 @@ def create_training_dataset(
Shuffles: list[int], optional
Alternatively the user can also give a list of shuffles.
- userfeedback: bool, optional, default=False
+ userfeedback: bool, optional, default=True
If ``False``, all requested train/test splits are created (no matter if they
already exist). If you want to assure that previous splits etc. are not
overwritten, set this to ``True`` and you will be asked for each split.
@@ -810,41 +822,83 @@ def create_training_dataset(
List of one or multiple lists containing test indexes.
net_type: list, optional, default=None
- Type of networks. Currently supported options are
-
- * ``resnet_50``
- * ``resnet_101``
- * ``resnet_152``
- * ``mobilenet_v2_1.0``
- * ``mobilenet_v2_0.75``
- * ``mobilenet_v2_0.5``
- * ``mobilenet_v2_0.35``
- * ``efficientnet-b0``
- * ``efficientnet-b1``
- * ``efficientnet-b2``
- * ``efficientnet-b3``
- * ``efficientnet-b4``
- * ``efficientnet-b5``
- * ``efficientnet-b6``
+ Type of networks. The options available depend on which engine is used.
+ Currently supported options are:
+ TensorFlow
+ * ``resnet_50``
+ * ``resnet_101``
+ * ``resnet_152``
+ * ``mobilenet_v2_1.0``
+ * ``mobilenet_v2_0.75``
+ * ``mobilenet_v2_0.5``
+ * ``mobilenet_v2_0.35``
+ * ``efficientnet-b0``
+ * ``efficientnet-b1``
+ * ``efficientnet-b2``
+ * ``efficientnet-b3``
+ * ``efficientnet-b4``
+ * ``efficientnet-b5``
+ * ``efficientnet-b6``
+ PyTorch (call ``deeplabcut.pose_estimation_pytorch.available_models()`` for
+ a complete list)
+ * ``resnet_50``
+ * ``resnet_101``
+ * ``hrnet_w18``
+ * ``hrnet_w32``
+ * ``hrnet_w48``
+ * ``dekr_w18``
+ * ``dekr_w32``
+ * ``dekr_w48``
+ * ``top_down_resnet_50``
+ * ``top_down_resnet_101``
+ * ``top_down_hrnet_w18``
+ * ``top_down_hrnet_w32``
+ * ``top_down_hrnet_w48``
+ * ``animaltokenpose_base``
+
+ detector_type: string, optional, default=None
+ Only for the PyTorch engine.
+ When passing creating shuffles for top-down models, you can specify which
+ detector you want. If the detector_type is None, the ```ssdlite``` will be used.
+ The list of all available detectors can be obtained by calling
+ ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported options:
+ * ``ssdlite``
+ * ``fasterrcnn_mobilenet_v3_large_fpn``
+ * ``fasterrcnn_resnet50_fpn_v2``
augmenter_type: string, optional, default=None
- Type of augmenter. Currently supported augmenters are
-
- * ``default``
- * ``scalecrop``
- * ``imgaug``
- * ``tensorpack``
- * ``deterministic``
+ Type of augmenter. The options available depend on which engine is used.
+ Currently supported options are:
+ TensorFlow
+ * ``default``
+ * ``scalecrop``
+ * ``imgaug``
+ * ``tensorpack``
+ * ``deterministic``
+ PyTorch
+ * ``albumentations``
posecfg_template: string, optional, default=None
+ Only for the TensorFlow engine.
Path to a ``pose_cfg.yaml`` file to use as a template for generating the new
one for the current iteration. Useful if you would like to start with the same
parameters a previous training iteration. None uses the default
``pose_cfg.yaml``.
superanimal_name: string, optional, default=""
- Specify the superanimal name is transfer learning with superanimal is desired. This makes sure the pose config template uses superanimal configs as template
+ Only for the TensorFlow engine. For the PyTorch engine, use the ``weight_init``
+ parameter.
+ Specify the superanimal name is transfer learning with superanimal is desired.
+ This makes sure the pose config template uses superanimal configs as template.
+ weight_init: WeightInitialisation, optional, default=None
+ PyTorch engine only. Specify how model weights should be initialized. The
+ default mode uses transfer learning from ImageNet weights.
+
+ engine: Engine, optional
+ Whether to create a pose config for a Tensorflow or PyTorch model. Defaults to
+ the value specified in the project configuration file. If no engine is specified
+ for the project, defaults to ``deeplabcut.compat.DEFAULT_ENGINE``.
Returns
-------
@@ -890,6 +944,7 @@ def create_training_dataset(
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()
if superanimal_name != "":
+ # FIXME(niels): this is deprecated
supermodels = parse_available_supermodels()
posecfg_template = os.path.join(
dlc_root_path,
@@ -922,12 +977,19 @@ def create_training_dataset(
num_shuffles,
Shuffles,
net_type=net_type,
+ detector_type=detector_type,
trainIndices=trainIndices,
testIndices=testIndices,
+ userfeedback=userfeedback,
+ engine=engine,
+ weight_init=weight_init,
)
else:
scorer = cfg["scorer"]
project_path = cfg["project_path"]
+ if engine is None:
+ engine = compat.get_project_engine(cfg)
+
# Create path for training sets & store data there
trainingsetfolder = auxiliaryfunctions.get_training_set_folder(
cfg
@@ -936,6 +998,11 @@ def create_training_dataset(
Path(os.path.join(project_path, str(trainingsetfolder))), recursive=True
)
+ # Create the trainset metadata file, if it doesn't yet exist
+ if not metadata.TrainingDatasetMetadata.path(cfg).exists():
+ trainset_metadata = metadata.TrainingDatasetMetadata.create(cfg)
+ trainset_metadata.save()
+
Data = merge_annotateddatasets(
cfg,
Path(os.path.join(project_path, trainingsetfolder)),
@@ -947,6 +1014,8 @@ def create_training_dataset(
# loading & linking pretrained models
if net_type is None: # loading & linking pretrained models
net_type = cfg.get("default_net_type", "resnet_50")
+ elif engine == Engine.PYTORCH:
+ pass
else:
if (
"resnet" in net_type
@@ -958,20 +1027,40 @@ def create_training_dataset(
else:
raise ValueError("Invalid network type:", net_type)
+ top_down = False
+ if engine == Engine.PYTORCH:
+ if net_type.startswith("top_down_"):
+ top_down = True
+ net_type = net_type[len("top_down_") :]
+
+ augmenters = compat.get_available_aug_methods(engine)
+ default_augmenter = augmenters[0]
if augmenter_type is None:
- augmenter_type = cfg.get("default_augmenter", "imgaug")
+ augmenter_type = cfg.get("default_augmenter", default_augmenter)
+
if augmenter_type is None: # this could be in config.yaml for old projects!
# updating variable if null/None! #backwardscompatability
- auxiliaryfunctions.edit_config(config, {"default_augmenter": "imgaug"})
- augmenter_type = "imgaug"
- elif augmenter_type not in [
- "default",
- "scalecrop",
- "imgaug",
- "tensorpack",
- "deterministic",
- ]:
- raise ValueError("Invalid augmenter type:", augmenter_type)
+ augmenter_type = default_augmenter
+ auxiliaryfunctions.edit_config(
+ config, {"default_augmenter": augmenter_type}
+ )
+ elif augmenter_type not in augmenters:
+ # as the default augmenter might not be available for the given engine
+ augmenter_type = default_augmenter
+ logging.info(
+ f"Default augmenter {augmenter_type} not available for engine "
+ f"{engine}: using {default_augmenter} instead"
+ )
+
+ if augmenter_type not in augmenters:
+ if engine != Engine.PYTORCH:
+ raise ValueError(
+ f"Invalid augmenter type: {augmenter_type} (available: for "
+ f"engine={engine}: {augmenters})"
+ )
+
+ logging.info(f"Switching augmentation to {default_augmenter} for PyTorch")
+ augmenter_type = default_augmenter
if posecfg_template:
if net_type != prior_cfg["net_type"]:
@@ -989,12 +1078,13 @@ def create_training_dataset(
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
elif posecfg_template:
defaultconfigfile = posecfg_template
- model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
- if Shuffles is None:
- Shuffles = range(1, num_shuffles + 1)
+ if engine == Engine.PYTORCH:
+ model_path = dlcparent_path
else:
- Shuffles = [i for i in Shuffles if isinstance(i, int)]
+ model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
+
+ Shuffles = validate_shuffles(cfg, Shuffles, num_shuffles, userfeedback)
# print(trainIndices,testIndices, Shuffles, augmenter_type,net_type)
if trainIndices is None and testIndices is None:
@@ -1032,15 +1122,16 @@ def create_training_dataset(
(trainFraction, Shuffles[shuffle], (train_inds, test_inds))
)
- bodyparts = cfg["bodyparts"]
+ bodyparts = auxiliaryfunctions.get_bodyparts(cfg)
nbodyparts = len(bodyparts)
for trainFraction, shuffle, (trainIndices, testIndices) in splits:
if len(trainIndices) > 0:
if userfeedback:
- trainposeconfigfile, _, _ = training.return_train_network_path(
+ trainposeconfigfile, _, _ = compat.return_train_network_path(
config,
shuffle=shuffle,
trainingsetindex=cfg["TrainingFraction"].index(trainFraction),
+ engine=engine,
)
if trainposeconfigfile.is_file():
askuser = input(
@@ -1087,13 +1178,25 @@ def create_training_dataset(
testIndices,
trainFraction,
)
+ metadata.update_metadata(
+ cfg=cfg,
+ train_fraction=trainFraction,
+ shuffle=shuffle,
+ engine=engine,
+ train_indices=trainIndices,
+ test_indices=testIndices,
+ overwrite=not userfeedback,
+ )
################################################################################
# Creating file structure for training &
# Test files as well as pose_yaml files (containing training and testing information)
#################################################################################
modelfoldername = auxiliaryfunctions.get_model_folder(
- trainFraction, shuffle, cfg
+ trainFraction,
+ shuffle,
+ cfg,
+ engine=engine,
)
auxiliaryfunctions.attempt_to_make_folder(
Path(config).parents[0] / modelfoldername, recursive=True
@@ -1110,7 +1213,7 @@ def create_training_dataset(
cfg["project_path"],
Path(modelfoldername),
"train",
- "pose_cfg.yaml",
+ engine.pose_cfg_name,
)
)
path_test_config = str(
@@ -1121,68 +1224,204 @@ def create_training_dataset(
"pose_cfg.yaml",
)
)
- # str(cfg['proj_path']+'/'+Path(modelfoldername) / 'test' / 'pose_cfg.yaml')
- items2change = {
- "dataset": datafilename,
- "metadataset": metadatafilename,
- "num_joints": len(bodyparts),
- "all_joints": [[i] for i in range(len(bodyparts))],
- "all_joints_names": [str(bpt) for bpt in bodyparts],
- "init_weights": model_path,
- "project_path": str(cfg["project_path"]),
- "net_type": net_type,
- "dataset_type": augmenter_type,
- }
-
- items2drop = {}
- if augmenter_type == "scalecrop":
- # these values are dropped as scalecrop
- # doesn't have rotation implemented
- items2drop = {"rotation": 0, "rotratio": 0.0}
- # Also drop maDLC smart cropping augmentation parameters
- for key in ["pre_resize", "crop_size", "max_shift", "crop_sampling"]:
- items2drop[key] = None
-
- trainingdata = MakeTrain_pose_yaml(
- items2change, path_train_config, defaultconfigfile, items2drop
- )
+ if engine == Engine.TF:
+ items2change = {
+ "dataset": datafilename,
+ "engine": engine.aliases[0],
+ "metadataset": metadatafilename,
+ "num_joints": len(bodyparts),
+ "all_joints": [[i] for i in range(len(bodyparts))],
+ "all_joints_names": [str(bpt) for bpt in bodyparts],
+ "init_weights": model_path,
+ "project_path": str(cfg["project_path"]),
+ "net_type": net_type,
+ "dataset_type": augmenter_type,
+ }
+
+ items2drop = {}
+ if augmenter_type == "scalecrop":
+ # these values are dropped as scalecrop
+ # doesn't have rotation implemented
+ items2drop = {"rotation": 0, "rotratio": 0.0}
+ # Also drop maDLC smart cropping augmentation parameters
+ for key in [
+ "pre_resize",
+ "crop_size",
+ "max_shift",
+ "crop_sampling",
+ ]:
+ items2drop[key] = None
+
+ trainingdata = MakeTrain_pose_yaml(
+ items2change,
+ path_train_config,
+ defaultconfigfile,
+ items2drop,
+ save=(engine == Engine.TF),
+ )
- keys2save = [
- "dataset",
- "num_joints",
- "all_joints",
- "all_joints_names",
- "net_type",
- "init_weights",
- "global_scale",
- "location_refinement",
- "locref_stdev",
- ]
- MakeTest_pose_yaml(trainingdata, keys2save, path_test_config)
- print(
- "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!"
- )
+ keys2save = [
+ "dataset",
+ "num_joints",
+ "all_joints",
+ "all_joints_names",
+ "net_type",
+ "init_weights",
+ "global_scale",
+ "location_refinement",
+ "locref_stdev",
+ ]
+ MakeTest_pose_yaml(trainingdata, keys2save, path_test_config)
+ print(
+ "The training dataset is successfully created. Use the function"
+ "'train_network' to start training. Happy training!"
+ )
+ elif engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
+ make_pytorch_pose_config,
+ make_pytorch_test_config,
+ )
+ from deeplabcut.pose_estimation_pytorch.modelzoo.config import (
+ make_super_animal_finetune_config,
+ )
+
+ if weight_init is not None and weight_init.with_decoder:
+ pytorch_cfg = make_super_animal_finetune_config(
+ project_config=cfg,
+ pose_config_path=path_train_config,
+ model_name=net_type,
+ detector_name=detector_type,
+ weight_init=weight_init,
+ save=True,
+ )
+ else:
+ pytorch_cfg = make_pytorch_pose_config(
+ project_config=cfg,
+ pose_config_path=path_train_config,
+ net_type=net_type,
+ top_down=top_down,
+ detector_type=detector_type,
+ weight_init=weight_init,
+ save=True,
+ )
+
+ make_pytorch_test_config(pytorch_cfg, path_test_config, save=True)
return splits
def get_largestshuffle_index(config):
"""Returns the largest shuffle for all dlc-models in the current iteration."""
- cfg = auxiliaryfunctions.read_config(config)
- project_path = cfg["project_path"]
- iterate = "iteration-" + str(cfg["iteration"])
- dlc_model_path = os.path.join(project_path, "dlc-models", iterate)
- if os.path.isdir(dlc_model_path):
- models = os.listdir(dlc_model_path)
- # sort the model directories
- models.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-
- # get the shuffle index and offset by 1.
- max_shuffle_index = int(models[-1].split("shuffle")[-1]) + 1
+ shuffle_indices = get_existing_shuffle_indices(config)
+ if len(shuffle_indices) > 0:
+ return shuffle_indices[-1]
+
+ return None
+
+
+def get_existing_shuffle_indices(
+ cfg: dict | str | Path,
+ train_fraction: float | None = None,
+ engine: Engine | None = None,
+) -> List[int]:
+ """
+ Args:
+ cfg: The content of a project configuration file, or the path to the project
+ configuration file.
+ train_fraction: If defined, only get the indices of shuffles with this train
+ fraction.
+ engine: If specified, returns only the shuffle indices that were created with
+ the given engine. Can only be used when train_fraction is also defined.
+
+ Returns:
+ the indices of existing shuffles for this iteration of the project, sorted by
+ ascending index
+ """
+
+ def is_valid_data_stem(stem: str) -> bool:
+ if len(stem) == 0:
+ return False
+ suffix = stem.split("_")[-1]
+ if len(suffix) == 0:
+ return False
+ info = suffix.split("shuffle")
+ if len(info) != 2:
+ return False
+ train_frac, idx = info
+ return (
+ train_frac.isdigit()
+ and idx.isdigit()
+ and (train_fraction is None or int(train_frac) == int(100 * train_fraction))
+ )
+
+ if isinstance(cfg, (str, Path)):
+ cfg = auxiliaryfunctions.read_config(cfg)
+
+ project = Path(cfg["project_path"])
+ trainset_folder = project / auxiliaryfunctions.get_training_set_folder(cfg)
+ if not trainset_folder.exists():
+ return []
+
+ shuffle_indices = [
+ int(p.stem.split("shuffle")[-1])
+ for p in trainset_folder.iterdir()
+ if (
+ p.stem.startswith("Documentation_data")
+ and p.suffix == ".pickle"
+ and is_valid_data_stem(p.stem)
+ )
+ ]
+ if engine is not None:
+ if train_fraction is None:
+ raise ValueError(
+ f"Must select {train_fraction} to filter shuffles by engine"
+ )
+
+ shuffle_indices = [
+ idx
+ for idx in shuffle_indices
+ if (
+ project
+ / auxiliaryfunctions.get_model_folder(
+ trainFraction=train_fraction,
+ shuffle=idx,
+ cfg=cfg,
+ engine=engine,
+ )
+ ).exists()
+ ]
+
+ return sorted(shuffle_indices)
+
+
+def validate_shuffles(
+ cfg: dict,
+ shuffles: list[int] | None,
+ num_shuffles: int | None,
+ userfeedback: bool,
+) -> list[int]:
+ existing_shuffles = get_existing_shuffle_indices(cfg)
+ if shuffles is None:
+ first_index = 1
+ if len(existing_shuffles) > 0:
+ first_index = existing_shuffles[-1] + 1
+
+ shuffles = range(first_index, num_shuffles + first_index)
else:
- max_shuffle_index = 0
+ shuffles = [i for i in shuffles if isinstance(i, int)]
+ for shuffle_idx in shuffles:
+ if userfeedback and shuffle_idx in existing_shuffles:
+ raise ValueError(
+ f"Cannot create shuffle {shuffle_idx} as it already exists - "
+ f"you must either create the dataset with `userfeedback=False` "
+ f"or delete the shuffle with index {shuffle_idx} manually (in "
+ f"`dlc-models`/`dlc-models-pytorch` and in the "
+ f"`training-datasets` folder) if you want to create a new "
+ f"shuffle with that index. You can otherwise create a shuffle "
+ f"with a new index. Existing indices are {existing_shuffles}."
+ )
- return max_shuffle_index
+ return shuffles
def create_training_model_comparison(
@@ -1301,7 +1540,7 @@ def create_training_model_comparison(
else:
pass
- largestshuffleindex = get_largestshuffle_index(config)
+ largestshuffleindex = get_existing_shuffle_indices(cfg)[-1] + 1
shuffle_list = []
for shuffle in range(num_shuffles):
@@ -1342,3 +1581,183 @@ def create_training_model_comparison(
logger.info(log_info)
return shuffle_list
+
+
+def create_training_dataset_from_existing_split(
+ config: str,
+ from_shuffle: int,
+ from_trainsetindex: int = 0,
+ num_shuffles: int = 1,
+ shuffles: list[int] | None = None,
+ userfeedback: bool = True,
+ net_type: str | None = None,
+ detector_type: str | None = None,
+ augmenter_type: str | None = None,
+ posecfg_template: dict | None = None,
+ superanimal_name: str = "",
+ weight_init: WeightInitialization | None = None,
+ engine: Engine | None = None,
+) -> None | list[int]:
+ """
+ Labels from all the extracted frames are merged into a single .h5 file.
+ Only the videos included in the config file are used to create this dataset.
+
+ Args:
+ config: Full path of the ``config.yaml`` file as a string.
+
+ from_shuffle: The index of the shuffle from which to copy the train/test split.
+
+ from_trainsetindex: The trainset index of the shuffle from which to use the data
+ split. Default is 0.
+
+ num_shuffles: Number of shuffles of training dataset to create, used if
+ ``shuffles`` is None.
+
+ shuffles: If defined, ``num_shuffles`` is ignored and a shuffle is created for
+ each index given in the list.
+
+ userfeedback: If ``False``, all requested train/test splits are created (no
+ matter if they already exist). If you want to assure that previous splits
+ etc. are not overwritten, set this to ``True`` and you will be asked for
+ each existing split if you want to overwrite it.
+
+ net_type: The type of network to create the shuffle for. Currently supported
+ options for engine=Engine.TF are:
+ * ``resnet_50``
+ * ``resnet_101``
+ * ``resnet_152``
+ * ``mobilenet_v2_1.0``
+ * ``mobilenet_v2_0.75``
+ * ``mobilenet_v2_0.5``
+ * ``mobilenet_v2_0.35``
+ * ``efficientnet-b0``
+ * ``efficientnet-b1``
+ * ``efficientnet-b2``
+ * ``efficientnet-b3``
+ * ``efficientnet-b4``
+ * ``efficientnet-b5``
+ * ``efficientnet-b6``
+ Currently supported options for engine=Engine.TF can be obtained by calling
+ ``deeplabcut.pose_estimation_pytorch.available_models()``.
+
+ detector_type: string, optional, default=None
+ Only for the PyTorch engine.
+ When passing creating shuffles for top-down models, you can specify which
+ detector you want. If the detector_type is None, the ```ssdlite``` will be
+ used. The list of all available detectors can be obtained by calling
+ ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported
+ options:
+ * ``ssdlite``
+ * ``fasterrcnn_mobilenet_v3_large_fpn``
+ * ``fasterrcnn_resnet50_fpn_v2``
+
+ augmenter_type: Type of augmenter. Currently supported augmenters for
+ engine=Engine.TF are
+ * ``default``
+ * ``scalecrop``
+ * ``imgaug``
+ * ``tensorpack``
+ * ``deterministic``
+ The only supported augmenter for Engine.PYTORCH is ``albumentations``.
+
+ posecfg_template: Only for Engine.TF. Path to a ``pose_cfg.yaml`` file to use as
+ a template for generating the new one for the current iteration. Useful if
+ you would like to start with the same parameters a previous training
+ iteration. None uses the default ``pose_cfg.yaml``.
+
+ superanimal_name: Specify the superanimal name is transfer learning with
+ superanimal is desired. This makes sure the pose config template uses
+ superanimal configs as template.
+
+ weight_init: Only for Engine.PYTORCH. Specify how model weights should be
+ initialized. The default mode uses transfer learning from ImageNet weights.
+
+ engine: Whether to create a pose config for a Tensorflow or PyTorch model.
+ Defaults to the value specified in the project configuration file. If no
+ engine is specified for the project, defaults to
+ ``deeplabcut.compat.DEFAULT_ENGINE``.
+
+ Returns:
+ If training dataset was successfully created, a list of tuples is returned.
+ The first two elements in each tuple represent the training fraction and the
+ shuffle value. The last two elements in each tuple are arrays of integers
+ representing the training and test indices.
+
+ Returns None if training dataset could not be created.
+
+ Raises:
+ ValueError: If the shuffle from which to copy the data split doesn't exist.
+ """
+ cfg = auxiliaryfunctions.read_config(config)
+ trainset_meta_path = metadata.TrainingDatasetMetadata.path(cfg)
+ if not trainset_meta_path.exists():
+ meta = metadata.TrainingDatasetMetadata.create(cfg)
+ meta.save()
+ else:
+ meta = metadata.TrainingDatasetMetadata.load(cfg, load_splits=False)
+
+ shuffle = meta.get(trainset_index=from_trainsetindex, index=from_shuffle)
+ shuffle = shuffle.load_split(cfg, trainset_path=trainset_meta_path.parent)
+
+ num_copies = num_shuffles
+ if shuffles is not None:
+ num_copies = len(shuffles)
+
+ # pad the train and test indices with -1s so the training fraction is exact
+ train_idx = list(shuffle.split.train_indices)
+ test_idx = list(shuffle.split.test_indices)
+ n_train, n_test = len(train_idx), len(test_idx)
+
+ train_fraction = round(cfg["TrainingFraction"][from_trainsetindex], 2)
+ if round(n_train / (n_train + n_test), 2) != train_fraction:
+ train_padding, test_padding = _compute_padding(train_fraction, n_train, n_test)
+ train_idx = train_idx + (train_padding * [-1])
+ test_idx = test_idx + (test_padding * [-1])
+
+ return create_training_dataset(
+ config=config,
+ num_shuffles=num_shuffles,
+ Shuffles=shuffles,
+ userfeedback=userfeedback,
+ trainIndices=[train_idx for _ in range(num_copies)],
+ testIndices=[test_idx for _ in range(num_copies)],
+ net_type=net_type,
+ detector_type=detector_type,
+ augmenter_type=augmenter_type,
+ posecfg_template=posecfg_template,
+ superanimal_name=superanimal_name,
+ weight_init=weight_init,
+ engine=engine,
+ )
+
+
+def _compute_padding(
+ train_fraction: float,
+ num_train: int,
+ num_test: int,
+) -> tuple[int, int]:
+ """
+ Computes the amount of padding to add to train/test indices such that
+ train_fraction = num_train / (num_train + num_test).
+
+ Returns:
+ the number of padding indices to add to the train indices
+ the number of padding indices to add to the test indices
+ """
+ if train_fraction <= 0 or train_fraction >= 1:
+ raise ValueError(
+ f"The training fraction must satisfy 0 < TrainingFraction < 1, but "
+ f"{train_fraction} was found"
+ )
+
+ base_images = 100
+ train_step = int(round(round(train_fraction, 2) * base_images))
+ test_step = base_images - train_step
+
+ tgt_train = train_step
+ tgt_test = test_step
+ while tgt_train < num_train or tgt_test < num_test:
+ tgt_train += train_step
+ tgt_test += test_step
+
+ return (tgt_train - num_train), (tgt_test - num_test)
diff --git a/deeplabcut/gui/components.py b/deeplabcut/gui/components.py
index 6fc9cfa7f5..5edc7712a4 100644
--- a/deeplabcut/gui/components.py
+++ b/deeplabcut/gui/components.py
@@ -89,6 +89,11 @@ def __init__(
self.itemSelectionChanged.connect(self.update_selected_bodyparts)
+ def refresh(self):
+ self.clear()
+ self.addItems(self.root.all_bodyparts)
+ self.update_selected_bodyparts()
+
def update_selected_bodyparts(self):
self.selected_bodyparts = [item.text() for item in self.selectedItems()]
self.root.logger.info(f"Selected bodyparts:\n\t{self.selected_bodyparts}")
@@ -179,6 +184,81 @@ def clear_selected_videos(self):
self.root.logger.info(f"Cleared selected videos")
+class SnapshotSelectionWidget(QtWidgets.QWidget):
+ def __init__(
+ self,
+ root: QtWidgets.QMainWindow,
+ parent: QtWidgets.QWidget,
+ margins: tuple,
+ select_button_text: str,
+ ):
+ super(SnapshotSelectionWidget, self).__init__(parent)
+
+ self.root = root
+ self.parent = parent
+
+ self.selected_snapshot = None
+
+ self._init_layout(margins, select_button_text)
+
+ def _init_layout(self, margins, select_button_text):
+ layout = _create_horizontal_layout(margins=margins)
+
+ # Select videos
+ self.select_snapshot_button = QtWidgets.QPushButton(select_button_text)
+ self.select_snapshot_button.setMaximumWidth(200)
+ self.select_snapshot_button.clicked.connect(self.select_snapshot)
+
+ # Selected snapshot text
+ self.selected_snapshot_text = QtWidgets.QLabel(
+ ""
+ ) # updated when snapshot is selected
+
+ # Clear snapshot selection
+ self.clear_snapshot_button = QtWidgets.QPushButton("Clear selection")
+ self.clear_snapshot_button.clicked.connect(self.clear_selected_snapshot)
+ self.clear_snapshot_button.hide()
+
+ layout.addWidget(self.select_snapshot_button)
+ layout.addWidget(self.selected_snapshot_text)
+ layout.addWidget(self.clear_snapshot_button, alignment=Qt.AlignRight)
+
+ self.setLayout(layout)
+
+ def _update_selected_snapshot_display(self):
+ if self.selected_snapshot is None:
+ self.selected_snapshot_text.setText("")
+ self.clear_snapshot_button.hide()
+ else:
+ self.selected_snapshot_text.setText(
+ f"{os.path.basename(self.selected_snapshot)}"
+ )
+ self.clear_snapshot_button.show()
+
+ def select_snapshot(self):
+ # Create a filter string with both lowercase and uppercase extensions
+ snapshot_types = ["*.pt", "*.PT"]
+ snapshot_files = f"Snapshots ({' '.join(snapshot_types)})"
+
+ directory_to_open = self.root.models_folder
+
+ selected_snapshot, _ = QtWidgets.QFileDialog.getOpenFileName(
+ self,
+ "Select snapshot to start training from",
+ directory_to_open,
+ snapshot_files,
+ )
+ # When Canceling a file selection, Qt returns an empty string as selected file
+ if selected_snapshot:
+ self.selected_snapshot = os.path.abspath(selected_snapshot)
+
+ self._update_selected_snapshot_display()
+
+ def clear_selected_snapshot(self):
+ self.selected_snapshot = None
+ self._update_selected_snapshot_display()
+
+
class TrainingSetSpinBox(QtWidgets.QSpinBox):
def __init__(self, root, parent):
super(TrainingSetSpinBox, self).__init__(parent)
@@ -201,6 +281,12 @@ def __init__(self, root, parent):
self.setMaximum(10_000)
self.setValue(self.root.shuffle_value)
self.valueChanged.connect(self.root.update_shuffle)
+ self.root.shuffle_change.connect(self.update_shuffle)
+
+ @Slot(int)
+ def update_shuffle(self, new_shuffle: int):
+ if new_shuffle != self.value():
+ self.setValue(new_shuffle)
class DefaultTab(QtWidgets.QWidget):
diff --git a/deeplabcut/gui/displays/__init__.py b/deeplabcut/gui/displays/__init__.py
new file mode 100644
index 0000000000..f511e6184a
--- /dev/null
+++ b/deeplabcut/gui/displays/__init__.py
@@ -0,0 +1,12 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+
diff --git a/deeplabcut/gui/displays/selected_shuffle_display.py b/deeplabcut/gui/displays/selected_shuffle_display.py
new file mode 100644
index 0000000000..13e6db1094
--- /dev/null
+++ b/deeplabcut/gui/displays/selected_shuffle_display.py
@@ -0,0 +1,130 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Module to display information about the selected shuffle in the GUI"""
+from __future__ import annotations
+from pathlib import Path
+
+import PySide6.QtCore as QtCore
+from PySide6 import QtWidgets
+
+from deeplabcut.core.engine import Engine
+from deeplabcut.utils import auxiliaryfunctions
+
+
+class SelectedShuffleDisplay(QtWidgets.QWidget):
+ """A widget displaying information about the selected shuffle"""
+ pose_cfg_signal = QtCore.Signal(dict)
+
+ def __init__(self, root, row_margin: int = 25):
+ super().__init__()
+ self.root = root
+
+ self._row_margin = row_margin
+
+ self._current_index: int | None = None
+ self._engine: Engine | None = None
+ self._is_top_down: bool = False
+ self._net_type: str | None = None
+ self._pose_cfg: dict | None = None
+
+ self._label = QtWidgets.QLabel("Shuffle info:")
+ self._label.setStyleSheet(f"margin: 0px 0px {self._row_margin}px 0px")
+ layout = QtWidgets.QHBoxLayout()
+ layout.addWidget(self._label)
+ self.setLayout(layout)
+
+ # initialize the display
+ self._update_display(self.root.shuffle_value)
+
+ # update the display when the shuffle or selected engine changes, or when a new
+ # shuffle has been created
+ self.root.shuffle_change.connect(self._update_display)
+ self.root.engine_change.connect(self._update_display)
+ self.root.shuffle_created.connect(self._update_display)
+
+ @property
+ def pose_cfg(self) -> dict | None:
+ return self._pose_cfg
+
+ @pose_cfg.setter
+ def pose_cfg(self, value: dict | None) -> None:
+ self._pose_cfg = value
+ self.pose_cfg_signal.emit(self._pose_cfg)
+
+ @QtCore.Slot(int)
+ def _update_display(self, new_index: int) -> None:
+ self._current_index = new_index
+
+ try:
+ pose_cfg_path = Path(self.root.pose_cfg_path)
+ except ValueError as err:
+ self._set_text_error(
+ f"Failed to read shuffle {self._current_index} - check that it exists!"
+ )
+ return
+ except ModuleNotFoundError as err:
+ # Loading a TF shuffle but TF is not installed
+ self._set_text_error(
+ f"Failed to read shuffle {self._current_index} due to error `{err}`.\n"
+ "If the error is `ModuleNotFoundError: No module named 'tensorflow'`, "
+ f"this is because\nshuffle {self._current_index} uses the tensorflow "
+ " engine, but TensorFlow is not installed in your environment.\n"
+ "Ignore this error if you'll just train PyTorch models. To train "
+ "TensorFlow models, install it with \n"
+ " Windows/Linux: pip install 'deeplabcut[tf]'\n"
+ " Apple Silicon: pip install 'deeplabcut[apple_mchips]'"
+ )
+ return
+
+ if not pose_cfg_path.exists():
+ self._set_text_error(
+ f"The model configuration file {pose_cfg_path} was not created"
+ )
+ return
+
+ self._read_pose_config(pose_cfg_path)
+ self._set_text()
+
+ def _set_text(self) -> None:
+ engine_str = "None"
+ if self._engine is not None:
+ engine_str = self._engine.aliases[0]
+
+ text = f"net type: {self._net_type} | engine: {engine_str}"
+ if self._engine == Engine.PYTORCH and self._is_top_down:
+ text += f" | top-down"
+
+ style = f"margin: 0px 0px {self._row_margin}px 0px;"
+ if self._engine != self.root.engine:
+ warning = "Change the selected Engine in the top-right to use this shuffle!"
+ text = warning + " | " + text
+ style += " color: orange;"
+
+ self._label.setStyleSheet(style)
+ self._label.setText(text)
+
+ def _set_text_error(self, error: str) -> None:
+ self._label.setText(error)
+ style = f"margin: 0px 0px {self._row_margin}px 0px; color: orange;"
+ self._label.setStyleSheet(style)
+ self.pose_cfg = None
+
+ def _read_pose_config(self, pose_cfg_path: Path) -> None:
+ pose_cfg = auxiliaryfunctions.read_plainconfig(str(pose_cfg_path))
+
+ self._engine = (
+ Engine.PYTORCH if "pytorch" in pose_cfg_path.stem.lower() else Engine.TF
+ )
+ self._net_type = pose_cfg.get("net_type", "UNKNOWN")
+ self._is_top_down = (
+ self._engine == Engine.PYTORCH and pose_cfg.get("method").lower() == "td"
+ )
+ self.pose_cfg = pose_cfg
diff --git a/deeplabcut/gui/displays/shuffle_metadata_viewer.py b/deeplabcut/gui/displays/shuffle_metadata_viewer.py
new file mode 100644
index 0000000000..b18aef85b4
--- /dev/null
+++ b/deeplabcut/gui/displays/shuffle_metadata_viewer.py
@@ -0,0 +1,63 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Widget to display existing shuffles"""
+from __future__ import annotations
+
+from PySide6 import QtWidgets
+from PySide6.QtCore import Qt
+
+import deeplabcut.generate_training_dataset.metadata as metadata
+
+
+class ShuffleMetadataViewer(QtWidgets.QDialog):
+ """Viewer for shuffle metadata"""
+
+ def __init__(self, root: QtWidgets.QMainWindow, parent: QtWidgets.QWidget):
+ super().__init__(parent)
+ self.root = root
+ self.parent = parent
+ self.file_content = _load_metadata(self.root.cfg)
+
+ self.setWindowTitle("Existing Shuffles: Metadata")
+ self.setMinimumWidth(400)
+ self.setMinimumHeight(400)
+
+ scroll = QtWidgets.QScrollArea()
+ scroll.setWidgetResizable(True)
+
+ inner_layout = QtWidgets.QVBoxLayout()
+ inner_layout.setAlignment(Qt.AlignLeft | Qt.AlignTop)
+ inner_layout.setSpacing(0)
+ inner_layout.setContentsMargins(0, 0, 0, 0)
+
+ for line in self.file_content:
+
+ inner_layout.addWidget(QtWidgets.QLabel(line))
+
+ inner = QtWidgets.QFrame(scroll)
+ inner.setLayout(inner_layout)
+ scroll.setWidget(inner)
+
+ layout = QtWidgets.QVBoxLayout()
+ layout.addWidget(scroll)
+ self.setLayout(layout)
+
+
+def _load_metadata(cfg: dict) -> list[str]:
+ metadata_path = metadata.TrainingDatasetMetadata.path(cfg)
+ if not metadata_path.exists():
+ trainset_meta = metadata.TrainingDatasetMetadata.create(cfg)
+ trainset_meta.save()
+
+ with open(metadata_path, "r") as file:
+ raw_metadata = file.read()
+
+ return raw_metadata.split("\n")
diff --git a/deeplabcut/gui/dlc_params.py b/deeplabcut/gui/dlc_params.py
index 6563682951..fce5d15a52 100644
--- a/deeplabcut/gui/dlc_params.py
+++ b/deeplabcut/gui/dlc_params.py
@@ -31,8 +31,6 @@ class DLCParams:
"efficientnet-b6",
]
- IMAGE_AUGMENTERS = ["default", "tensorpack", "imgaug"]
-
FRAME_EXTRACTION_ALGORITHMS = ["kmeans", "uniform"]
OUTLIER_EXTRACTION_ALGORITHMS = ["jump", "fitting", "uncertain", "manual"]
diff --git a/deeplabcut/gui/media/dlc-pt.png b/deeplabcut/gui/media/dlc-pt.png
new file mode 100644
index 0000000000..d0ac99c187
Binary files /dev/null and b/deeplabcut/gui/media/dlc-pt.png differ
diff --git a/deeplabcut/gui/media/dlc-tf.png b/deeplabcut/gui/media/dlc-tf.png
new file mode 100644
index 0000000000..79d06f0528
Binary files /dev/null and b/deeplabcut/gui/media/dlc-tf.png differ
diff --git a/deeplabcut/gui/tabs/analyze_videos.py b/deeplabcut/gui/tabs/analyze_videos.py
index e608839701..60971f4c3e 100644
--- a/deeplabcut/gui/tabs/analyze_videos.py
+++ b/deeplabcut/gui/tabs/analyze_videos.py
@@ -27,6 +27,7 @@
import deeplabcut
from deeplabcut.utils.auxiliaryfunctions import edit_config
+from deeplabcut.utils import auxfun_multianimal
class AnalyzeVideos(DefaultTab):
@@ -198,36 +199,36 @@ def _generate_layout_multianimal(self, layout):
layout.addLayout(tmp_layout)
def update_create_video_detections(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Create video with all detections {s}")
def update_assemble_with_ID_only(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Assembly with ID only {s}")
def update_calibrate_assembly(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Assembly calibration {s}")
def update_tracker_type(self, method):
self.root.logger.info(f"Using {method.upper()} tracker")
def update_csv_choice(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Save results as CSV {s}")
def update_filter_choice(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Filtering predictions {s}")
def update_showfigs_choice(self, state):
- if state == Qt.Checked:
+ if Qt.CheckState(state) == Qt.Checked:
self.root.logger.info("Plots will show as pop ups.")
else:
self.root.logger.info("Plots will not show up.")
def update_crop_choice(self, state):
- if state == Qt.Checked:
+ if Qt.CheckState(state) == Qt.Checked:
self.root.logger.info("Dynamic bodypart cropping ENABLED.")
self.dynamic_cropping = True
else:
@@ -235,7 +236,8 @@ def update_crop_choice(self, state):
self.dynamic_cropping = False
def update_plot_trajectory_choice(self, state):
- if state == Qt.Checked:
+ if Qt.CheckState(state) == Qt.Checked:
+ self.bodyparts_list_widget.refresh()
self.bodyparts_list_widget.show()
self.bodyparts_list_widget.setEnabled(True)
self.show_trajectory_plots.setEnabled(True)
@@ -263,12 +265,8 @@ def analyze_videos(self):
videotype = self.video_selection_widget.videotype_widget.currentText()
if self.root.is_multianimal:
- calibrate_assembly = (
- self.calibrate_assembly_checkbox.isChecked()
- )
- assemble_with_ID_only = (
- self.assemble_with_ID_only_checkbox.isChecked()
- )
+ calibrate_assembly = self.calibrate_assembly_checkbox.isChecked()
+ assemble_with_ID_only = self.assemble_with_ID_only_checkbox.isChecked()
track_method = self.tracker_type_widget.currentText()
edit_config(self.root.config, {"default_track_method": track_method})
num_animals_in_videos = self.num_animals_in_videos.value()
@@ -339,6 +337,7 @@ def run_enabled(self):
shuffle=shuffle,
)
+ track_method = auxfun_multianimal.get_track_method(self.root.cfg)
if filter_data:
deeplabcut.filterpredictions(
config,
@@ -348,6 +347,7 @@ def run_enabled(self):
filtertype="median",
windowlength=5,
save_as_csv=save_as_csv,
+ track_method=track_method,
)
if self.plot_trajectories.isChecked():
@@ -355,7 +355,6 @@ def run_enabled(self):
self.root.logger.debug(
f"Selected body parts for plot_trajectories: {bdpts}"
)
- showfig = self.show_trajectory_plots.isChecked()
deeplabcut.plot_trajectories(
config,
videos=videos,
@@ -363,7 +362,8 @@ def run_enabled(self):
videotype=videotype,
shuffle=shuffle,
filtered=filter_data,
- showfigures=showfig,
+ showfigures=self.show_trajectory_plots.isChecked(),
+ track_method=track_method,
)
if self.root.is_multianimal and save_as_csv:
diff --git a/deeplabcut/gui/tabs/create_project.py b/deeplabcut/gui/tabs/create_project.py
index 6eaf66fa38..95331613b2 100644
--- a/deeplabcut/gui/tabs/create_project.py
+++ b/deeplabcut/gui/tabs/create_project.py
@@ -4,24 +4,187 @@
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
import os
from datetime import datetime
+from PySide6 import QtCore, QtWidgets
+from PySide6.QtGui import QBrush, QColor, QDesktopServices, QIcon, QPainter, QPen
+
import deeplabcut
-from deeplabcut.utils import auxiliaryfunctions
from deeplabcut.gui import BASE_DIR
from deeplabcut.gui.dlc_params import DLCParams
from deeplabcut.gui.widgets import ClickableLabel, ItemSelectionFrame
+from deeplabcut.gui.tabs.docs import (
+ URL_3D,
+ URL_MA_CONFIGURE,
+ URL_USE_GUIDE_SCENARIO,
+)
+from deeplabcut.utils import auxiliaryfunctions
-from PySide6 import QtCore, QtWidgets
-from PySide6.QtGui import QIcon
+
+class DynamicTextList(QtWidgets.QWidget):
+ """Dynamically add text entries"""
+
+ def __init__(self, label_text="bodyparts", parent=None):
+ super(DynamicTextList, self).__init__(parent)
+ self.label_text = label_text
+ self.layout = QtWidgets.QVBoxLayout(self)
+ self.layout.setContentsMargins(0, 0, 0, 0)
+
+ # Set maximum width for the widget
+ self.setMaximumWidth(300)
+
+ # Add explanatory label
+ label = QtWidgets.QLabel(label_text)
+ self.layout.addWidget(label)
+
+ # Create scroll area and its widget
+ self.scroll = QtWidgets.QScrollArea()
+ self.scroll.setWidgetResizable(True)
+ self.scroll.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
+ self.scroll.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded)
+ self.scroll.setFrameShape(QtWidgets.QFrame.NoFrame) # Remove frame border
+
+ # Create widget to hold the entries
+ self.entries_widget = QtWidgets.QWidget()
+ self.entries_layout = QtWidgets.QVBoxLayout(self.entries_widget)
+ self.entries_layout.setContentsMargins(0, 0, 0, 0)
+ self.entries_layout.setSpacing(5) # Consistent spacing between entries
+ self.entries_layout.setAlignment(QtCore.Qt.AlignTop) # Align entries to top
+
+ # Add stretch at the bottom to keep entries at top
+ self.entries_layout.addStretch()
+
+ self.scroll.setWidget(self.entries_widget)
+
+ # Set fixed height for 6 items
+ self.entry_height = 30 # Fixed height for each entry
+ self.padding = 10 # Extra padding
+ self.scroll.setFixedHeight(5 * self.entry_height + self.padding)
+
+ # Add scroll area to main layout
+ self.layout.addWidget(self.scroll)
+
+ self.entries = []
+ self.add_entry()
+
+ def add_entry(self):
+ # Create horizontal layout for index and entry
+ entry_layout = QtWidgets.QHBoxLayout()
+ entry_layout.setContentsMargins(0, 0, 10, 0)
+ entry_layout.setSpacing(5) # Consistent spacing between index and entry
+
+ # Create container widget for the entry row
+ entry_widget = QtWidgets.QWidget()
+ entry_widget.setFixedHeight(self.entry_height)
+ entry_widget.setLayout(entry_layout)
+
+ # Add index label
+ index_label = QtWidgets.QLabel(str(len(self.entries) + 1) + ".")
+ index_label.setFixedWidth(20) # Set fixed width for alignment
+ entry_layout.addWidget(index_label)
+
+ # Add text entry
+ entry = QtWidgets.QLineEdit()
+ entry.setFixedHeight(self.entry_height - 6) # Slightly smaller than container
+ entry.textChanged.connect(self._on_text_changed)
+ entry.textEdited.connect(lambda text: self._check_for_spaces(entry, text))
+ self.entries.append((entry, index_label)) # Store both widgets
+ entry_layout.addWidget(entry)
+
+ # Insert the new entry before the stretch
+ self.entries_layout.insertWidget(len(self.entries) - 1, entry_widget)
+
+ def _check_for_spaces(self, entry, text):
+ if " " in text:
+ msg = QtWidgets.QMessageBox()
+ msg.setIcon(QtWidgets.QMessageBox.Warning)
+ msg.setText(
+ f"Spaces are not allowed in the {self.label_text} list. Use underscores "
+ f"instead."
+ )
+ msg.setWindowTitle("Warning")
+ msg.exec_()
+ entry.setText(entry.text().replace(" ", "_"))
+
+ def _on_text_changed(self):
+ # If the last entry has text, add a new empty entry
+ if self.entries[-1][0].text():
+ self.add_entry()
+
+ # Remove any empty entries except the last one
+ entries_to_remove = []
+ for i, (entry, _) in enumerate(self.entries[:-1]):
+ if not entry.text():
+ entries_to_remove.append(i)
+
+ for i in reversed(entries_to_remove):
+ entry_widget = self.entries[i][0].parent()
+ self.entries_layout.removeWidget(entry_widget)
+ entry_widget.deleteLater()
+ self.entries.pop(i)
+
+ self._update_indices() # Update the indices after removal
+
+ def get_entries(self):
+ return [entry[0].text() for entry in self.entries if entry[0].text()]
+
+ def _update_indices(self):
+ for i, (entry, index_label) in enumerate(self.entries):
+ index_label.setText(str(i + 1) + ".")
+
+
+class Switch(QtWidgets.QPushButton):
+
+ def __init__(self, on_text="Yes", off_text="No", width=80, parent=None):
+ super().__init__(parent)
+ self.on_text = on_text
+ self.off_text = off_text
+ self.setCheckable(True)
+ self.setFixedWidth(width)
+ self.setMinimumHeight(22)
+
+ def paintEvent(self, event):
+ # Colors: https://qdarkstylesheet.readthedocs.io/en/latest/color_reference.html
+ label = self.on_text if self.isChecked() else self.off_text
+ bg_color = "#00ff00" if self.isChecked() else "#9DA9B5"
+
+ radius = 10
+ width = 32
+ center = self.rect().center()
+
+ painter = QPainter(self)
+ painter.setRenderHint(QPainter.Antialiasing)
+ painter.translate(center)
+ painter.setBrush(QColor(69, 83, 100)) # Lighter gray background
+
+ pen = QPen("#455364")
+ pen.setWidth(2)
+ painter.setPen(pen)
+
+ painter.drawRoundedRect(
+ QtCore.QRect(-width, -radius, 2 * width, 2 * radius), radius, radius
+ )
+ painter.setBrush(QBrush(bg_color))
+ sw_rect = QtCore.QRect(-radius, -radius, width + radius, 2 * radius)
+ if not self.isChecked():
+ sw_rect.moveLeft(-width)
+
+ painter.drawRoundedRect(sw_rect, radius, radius)
+
+ pen = QPen("#000000")
+ pen.setWidth(2)
+ painter.setPen(pen)
+ painter.drawText(sw_rect, QtCore.Qt.AlignCenter, label)
class ProjectCreator(QtWidgets.QDialog):
+ """Project creation dialog"""
+
def __init__(self, parent):
super(ProjectCreator, self).__init__(parent)
self.parent = parent
@@ -34,6 +197,19 @@ def __init__(self, parent):
self.exp_default = ""
self.loc_default = parent.project_folder
+ self.bodypart_list = None
+ self.individuals_list = None
+ self.unique_bodyparts_list = None
+
+ self.toggle_3d = Switch()
+ self.toggle_3d.setChecked(False)
+ self.madlc_toggle = Switch()
+ self.madlc_toggle.setChecked(False)
+ self.unique_toggle = Switch()
+ self.unique_toggle.setChecked(False)
+ self.identity_toggle = Switch()
+ self.identity_toggle.setChecked(False)
+
main_layout = QtWidgets.QVBoxLayout(self)
self.user_frame = self.lay_out_user_frame()
self.video_frame = self.lay_out_video_frame()
@@ -80,46 +256,172 @@ def lay_out_user_frame(self):
grid.addWidget(self.loc_line, 2, 1)
vbox.addLayout(grid)
- self.madlc_box = QtWidgets.QCheckBox("Is it a multi-animal project?")
- self.madlc_box.setChecked(False)
- vbox.addWidget(self.madlc_box)
+ widget_3d = self.build_toggle_widget(
+ switch=self.toggle_3d,
+ question="Do you want to create a 3D pose estimation project?",
+ help_text="(What is needed for a 3D project?)",
+ docs_link=URL_3D,
+ )
+ madlc_widget = self.build_toggle_widget(
+ switch=self.madlc_toggle,
+ question="Are there multiple individuals in your videos?",
+ help_text="(Why does this matter?)",
+ docs_link=URL_USE_GUIDE_SCENARIO,
+ )
+
+ # Only visible when the maDLC widget is checked
+ unique_widget = self.build_toggle_widget(
+ switch=self.unique_toggle,
+ question="Do you have unique bodyparts in your video?",
+ help_text="(What are unique bodyparts?)",
+ docs_link=URL_MA_CONFIGURE,
+ )
+ unique_widget.setVisible(False)
+
+ # Labelling with identity
+ identity_widget = self.build_toggle_widget(
+ switch=self.identity_toggle,
+ question="Label with identity?",
+ help_text="(What is labeling with identity?)",
+ docs_link=URL_MA_CONFIGURE,
+ )
+ identity_widget.setVisible(False)
+
+ vbox.addWidget(widget_3d, alignment=QtCore.Qt.AlignTop)
+ vbox.addWidget(madlc_widget, alignment=QtCore.Qt.AlignTop)
+ vbox.addWidget(unique_widget, alignment=QtCore.Qt.AlignTop)
+ vbox.addWidget(identity_widget, alignment=QtCore.Qt.AlignTop)
+
+ # Create horizontal layout for the two lists
+ lists_layout = QtWidgets.QHBoxLayout()
+ lists_layout.setAlignment(QtCore.Qt.AlignTop)
+
+ # Create both DynamicTextList widgets as class attributes
+ self.bodypart_list = DynamicTextList(
+ label_text="Bodyparts to track",
+ parent=self,
+ )
+
+ self.individuals_list = DynamicTextList(
+ label_text="Individual names",
+ parent=self,
+ )
+ self.individuals_list.setVisible(False)
+
+ self.unique_bodyparts_list = DynamicTextList(
+ label_text="Unique bodyparts to track",
+ parent=self,
+ )
+ self.unique_bodyparts_list.setVisible(False)
+
+ # Connect toggle state to individuals list visibility, unique, identity
+ self.madlc_toggle.toggled.connect(self.individuals_list.setVisible)
+ self.madlc_toggle.toggled.connect(unique_widget.setVisible)
+ self.madlc_toggle.toggled.connect(identity_widget.setVisible)
+
+ # Connect the unique_toggle to the unique_bodyparts_list
+ self.unique_toggle.toggled.connect(
+ lambda yes: self.unique_bodyparts_list.setVisible(
+ yes and self.madlc_toggle.isChecked()
+ )
+ )
+
+ # Connect 3d toggle to all other option visibility
+ self.toggle_3d.toggled.connect(lambda yes: madlc_widget.setVisible(not yes))
+ self.toggle_3d.toggled.connect(
+ lambda checked_3d: unique_widget.setVisible(
+ not checked_3d and self.madlc_toggle.isChecked()
+ )
+ )
+ self.toggle_3d.toggled.connect(
+ lambda checked_3d: identity_widget.setVisible(
+ not checked_3d and self.madlc_toggle.isChecked()
+ )
+ )
+ self.toggle_3d.toggled.connect(
+ lambda checked_3d: self.bodypart_list.setVisible(not checked_3d)
+ )
+ self.toggle_3d.toggled.connect(
+ lambda checked_3d: self.individuals_list.setVisible(
+ not checked_3d and self.madlc_toggle.isChecked()
+ )
+ )
+ self.toggle_3d.toggled.connect(
+ lambda checked_3d: self.unique_bodyparts_list.setVisible(
+ not checked_3d
+ and self.madlc_toggle.isChecked()
+ and self.unique_toggle.isChecked()
+ )
+ )
+ # Add both lists to the horizontal layout with top alignment
+ lists_layout.addWidget(self.bodypart_list, alignment=QtCore.Qt.AlignTop)
+ lists_layout.addWidget(self.individuals_list, alignment=QtCore.Qt.AlignTop)
+ lists_layout.addWidget(self.unique_bodyparts_list, alignment=QtCore.Qt.AlignTop)
+
+ # Add the horizontal layout to the main vertical layout
+ vbox.addLayout(lists_layout)
return user_frame
+ def build_toggle_widget(
+ self,
+ switch: Switch,
+ question: str,
+ help_text: str,
+ docs_link: str,
+ ) -> QtWidgets.QWidget:
+ toggle_layout = QtWidgets.QHBoxLayout()
+ toggle_layout.setContentsMargins(0, 0, 0, 0)
+ toggle_layout.setSpacing(10)
+
+ toggle_label = QtWidgets.QLabel(question)
+ toggle_label.setAlignment(QtCore.Qt.AlignLeft)
+ help_label = ClickableLabel(help_text, parent=self)
+ help_label.setStyleSheet("text-decoration: underline; font-weight: bold;")
+ help_label.setCursor(QtCore.Qt.PointingHandCursor)
+ help_label.signal.connect(
+ lambda: QDesktopServices.openUrl(QtCore.QUrl(docs_link))
+ )
+
+ toggle_layout.addWidget(switch, alignment=QtCore.Qt.AlignLeft)
+ toggle_layout.addWidget(toggle_label, alignment=QtCore.Qt.AlignLeft)
+ toggle_layout.addStretch()
+ toggle_layout.addWidget(help_label, alignment=QtCore.Qt.AlignRight)
+ toggle_widget = QtWidgets.QWidget()
+ toggle_widget.setLayout(toggle_layout)
+ return toggle_widget
+
def lay_out_video_frame(self):
video_frame = ItemSelectionFrame([], self)
- self.cam_combo = QtWidgets.QComboBox(video_frame)
- self.cam_combo.addItems(map(str, (1, 2)))
- self.cam_combo.currentTextChanged.connect(self.check_num_cameras)
- ncam_label = QtWidgets.QLabel("Number of cameras:")
- ncam_label.setBuddy(self.cam_combo)
-
self.copy_box = QtWidgets.QCheckBox("Copy videos to project folder")
self.copy_box.setChecked(False)
- browse_button = QtWidgets.QPushButton("Browse videos")
+ browse_button = QtWidgets.QPushButton("Browse folders for videos")
browse_button.clicked.connect(self.browse_videos)
clear_button = QtWidgets.QPushButton("Clear")
clear_button.clicked.connect(video_frame.fancy_list.clear)
- layout1 = QtWidgets.QHBoxLayout()
- layout1.addWidget(ncam_label)
- layout1.addWidget(self.cam_combo)
- layout2 = QtWidgets.QHBoxLayout()
- layout2.addWidget(browse_button)
- layout2.addWidget(clear_button)
- video_frame.layout.insertLayout(0, layout1)
- video_frame.layout.addLayout(layout2)
+ layout = QtWidgets.QHBoxLayout()
+ layout.addWidget(browse_button)
+ layout.addWidget(clear_button)
+ video_frame.layout.addLayout(layout)
video_frame.layout.addWidget(self.copy_box)
+ self.toggle_3d.toggled.connect(lambda yes: self.copy_box.setVisible(not yes))
+ self.toggle_3d.toggled.connect(lambda yes: browse_button.setVisible(not yes))
+ self.toggle_3d.toggled.connect(lambda yes: clear_button.setVisible(not yes))
+ self.toggle_3d.toggled.connect(lambda yes: video_frame.setVisible(not yes))
return video_frame
def browse_videos(self):
+ options = QtWidgets.QFileDialog.Options()
+ options |= QtWidgets.QFileDialog.DontUseNativeDialog
folder = QtWidgets.QFileDialog.getExistingDirectory(
self,
"Please select a folder",
self.loc_default,
+ options,
)
if not folder:
return
@@ -128,7 +430,7 @@ def browse_videos(self):
folder,
relative=False,
):
- if os.path.splitext(video)[1][1:] in DLCParams.VIDEOTYPES[1:]:
+ if os.path.splitext(video)[1][1:].lower() in DLCParams.VIDEOTYPES[1:]:
self.video_frame.fancy_list.add_item(video)
def finalize_project(self):
@@ -142,13 +444,13 @@ def finalize_project(self):
if empty:
return
- n_cameras = int(self.cam_combo.currentText())
+ create_3d = self.toggle_3d.isChecked()
try:
- if n_cameras > 1:
+ if create_3d:
_ = deeplabcut.create_new_project_3d(
self.proj_default,
self.exp_default,
- n_cameras,
+ 2,
self.loc_default,
)
else:
@@ -162,7 +464,7 @@ def finalize_project(self):
self.video_frame.fancy_list._default_style
)
to_copy = self.copy_box.isChecked()
- is_madlc = self.madlc_box.isChecked()
+ is_madlc = self.madlc_toggle.isChecked()
config = deeplabcut.create_new_project(
self.proj_default,
self.exp_default,
@@ -171,16 +473,44 @@ def finalize_project(self):
to_copy,
multianimal=is_madlc,
)
+
+ if self.bodypart_list is not None:
+ bodypart_key = "bodyparts"
+ updates = {}
+ if is_madlc:
+ bodypart_key = "multianimalbodyparts"
+ if self.individuals_list is not None:
+ individuals = self.individuals_list.get_entries()
+ if len(individuals) > 0:
+ updates["individuals"] = individuals
+
+ if (
+ self.unique_toggle.isChecked()
+ and self.unique_bodyparts_list is not None
+ ):
+ unique_bodyparts = self.unique_bodyparts_list.get_entries()
+ if len(unique_bodyparts) > 0:
+ updates["uniquebodyparts"] = unique_bodyparts
+
+ if self.identity_toggle.isChecked():
+ updates["identity"] = True
+
+ bodyparts = self.bodypart_list.get_entries()
+ if len(bodyparts) > 0:
+ updates[bodypart_key] = bodyparts
+
+ if len(updates) > 0:
+ cfg: dict = auxiliaryfunctions.read_config(config)
+ cfg.update(**updates)
+ auxiliaryfunctions.write_config(config, cfg)
+
self.parent.load_config(config)
- self.parent._update_project_state(
- config=config,
- loaded=True,
- )
+ self.parent._update_project_state(config=config, loaded=True)
except FileExistsError:
print('Project "{}" already exists!'.format(self.proj_default))
return
- msg = QtWidgets.QMessageBox(text=f"New project created")
+ msg = QtWidgets.QMessageBox(text="New project created")
msg.setIcon(QtWidgets.QMessageBox.Information)
msg.exec_()
@@ -195,15 +525,6 @@ def on_click(self):
self.loc_default = dirname
self.update_project_location()
- def check_num_cameras(self, value):
- val = int(value)
- for child in self.video_frame.children():
- if child.isWidgetType() and not isinstance(child, QtWidgets.QComboBox):
- if val > 1:
- child.setDisabled(True)
- else:
- child.setDisabled(False)
-
def update_project_name(self, text):
self.proj_default = text
self.update_project_location()
diff --git a/deeplabcut/gui/tabs/create_training_dataset.py b/deeplabcut/gui/tabs/create_training_dataset.py
index 7101b0e5b3..15c06418c3 100644
--- a/deeplabcut/gui/tabs/create_training_dataset.py
+++ b/deeplabcut/gui/tabs/create_training_dataset.py
@@ -8,12 +8,23 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+from __future__ import annotations
+
import os
+from pathlib import Path
+import dlclibrary
from PySide6 import QtWidgets
-from PySide6.QtCore import Qt
+from PySide6.QtCore import Qt, Slot
from PySide6.QtGui import QIcon
+import deeplabcut
+import deeplabcut.compat as compat
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.generate_training_dataset import get_existing_shuffle_indices
+from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine
+from deeplabcut.gui.displays.shuffle_metadata_viewer import ShuffleMetadataViewer
from deeplabcut.gui.dlc_params import DLCParams
from deeplabcut.gui.components import (
DefaultTab,
@@ -21,8 +32,8 @@
_create_grid_layout,
_create_label_widget,
)
-
-import deeplabcut
+from deeplabcut.gui.widgets import launch_napari
+from deeplabcut.modelzoo import build_weight_init
from deeplabcut.utils.auxiliaryfunctions import (
get_data_and_metadata_filenames,
get_training_set_folder,
@@ -40,16 +51,36 @@ def __init__(self, root, parent, h1_description):
self._generate_layout_attributes(self.layout_attributes)
self.main_layout.addLayout(self.layout_attributes)
+ self.mapping_button = QtWidgets.QPushButton("Edit Conversion Table")
+ self.mapping_button.clicked.connect(self.edit_conversion_table)
+ self.mapping_button.setVisible(False)
+ self.root.engine_change.connect(self.set_edit_table_visibility)
+
self.ok_button = QtWidgets.QPushButton("Create Training Dataset")
self.ok_button.setMinimumWidth(150)
self.ok_button.clicked.connect(self.create_training_dataset)
+ self.main_layout.addWidget(self.mapping_button, alignment=Qt.AlignRight)
self.main_layout.addWidget(self.ok_button, alignment=Qt.AlignRight)
+ self.view_shuffles_button = QtWidgets.QPushButton("View Existing Shuffles")
+ self.view_shuffles_button.clicked.connect(self.view_shuffles)
+ self.main_layout.addWidget(self.view_shuffles_button, alignment=Qt.AlignLeft)
+
self.help_button = QtWidgets.QPushButton("Help")
self.help_button.clicked.connect(self.show_help_dialog)
self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft)
+ def set_edit_table_visibility(self) -> None:
+ has_conversion_tables = bool(
+ self.root.cfg.get("SuperAnimalConversionTables", {})
+ )
+ is_pytorch_engine = self.root.engine == Engine.PYTORCH
+ is_finetuning = self.weight_init_selector.with_decoder
+ self.mapping_button.setVisible(
+ has_conversion_tables & is_pytorch_engine & is_finetuning
+ )
+
def show_help_dialog(self):
dialog = QtWidgets.QDialog(self)
layout = QtWidgets.QVBoxLayout()
@@ -74,29 +105,77 @@ def _generate_layout_attributes(self, layout):
shuffle_label = QtWidgets.QLabel("Shuffle")
self.shuffle = ShuffleSpinBox(root=self.root, parent=self)
+ # Dataset choices
+ self.weight_init_label = QtWidgets.QLabel("Weight Initialization")
+ self.weight_init_selector = WeightInitializationSelector(self.root)
+ self.update_weight_init_methods(self.root.engine)
+ self.root.engine_change.connect(self.update_weight_init_methods)
+
# Augmentation method
augmentation_label = QtWidgets.QLabel("Augmentation method")
self.aug_choice = QtWidgets.QComboBox()
- self.aug_choice.addItems(DLCParams.IMAGE_AUGMENTERS)
- self.aug_choice.setCurrentText("imgaug")
+ self.update_aug_methods(self.root.engine)
+ self.root.engine_change.connect(self.update_aug_methods)
self.aug_choice.currentTextChanged.connect(self.log_augmentation_choice)
# Neural Network
nnet_label = QtWidgets.QLabel("Network architecture")
self.net_choice = QtWidgets.QComboBox()
- nets = DLCParams.NNETS.copy()
- if not self.root.is_multianimal:
- nets.remove("dlcrnet_ms5")
- self.net_choice.addItems(nets)
- self.net_choice.setCurrentText("resnet_50")
+ self.net_choice.setMinimumWidth(200)
+ self.update_nets(self.root.engine)
+ self.root.engine_change.connect(self.update_nets)
self.net_choice.currentTextChanged.connect(self.log_net_choice)
+ # Update Net types when selected weight init changes
+ self.weight_init_selector.weight_init_choice.currentTextChanged.connect(
+ lambda _: self.update_nets(None)
+ )
+ self.weight_init_selector.weight_init_choice.currentTextChanged.connect(
+ lambda _: self.set_edit_table_visibility()
+ )
+
+ # Detector selection for top-down models
+ self.detector_label = QtWidgets.QLabel("Detector architecture")
+ self.detector_choice = QtWidgets.QComboBox()
+ self.detector_choice.setMinimumWidth(200)
+ self.update_detectors(engine=self.root.engine)
+ self.root.engine_change.connect(
+ lambda engine: self.update_detectors(engine=engine)
+ )
+ self.net_choice.currentTextChanged.connect(
+ lambda new_net_choice: self.update_detectors(net_choice=new_net_choice)
+ )
+
+ # Overwrite selection
+ self.overwrite = QtWidgets.QCheckBox("Overwrite if exists")
+ self.overwrite.setChecked(False)
+ self.overwrite.setToolTip(
+ "When checked, creating a new shuffle with an index that already exists "
+ "will overwrite the existing index. Be careful with this option as you "
+ "might lose data."
+ )
+ self.overwrite.stateChanged.connect(
+ lambda s: self.root.logger.info(f"Overwrite: {s}")
+ )
+
+ # Use same data split as another shuffle
+ self.data_split_selection = DataSplitSelector(self.root, self)
+
layout.addWidget(shuffle_label, 0, 0)
layout.addWidget(self.shuffle, 0, 1)
- layout.addWidget(nnet_label, 0, 2)
- layout.addWidget(self.net_choice, 0, 3)
- layout.addWidget(augmentation_label, 0, 4)
- layout.addWidget(self.aug_choice, 0, 5)
+ layout.addWidget(self.weight_init_label, 0, 2)
+ layout.addWidget(self.weight_init_selector, 0, 3)
+
+ layout.addWidget(nnet_label, 1, 0)
+ layout.addWidget(self.net_choice, 1, 1)
+ layout.addWidget(augmentation_label, 1, 2)
+ layout.addWidget(self.aug_choice, 1, 3)
+
+ layout.addWidget(self.detector_label, 2, 0)
+ layout.addWidget(self.detector_choice, 2, 1)
+
+ layout.addWidget(self.overwrite, 3, 0)
+ layout.addWidget(self.data_split_selection, 4, 0)
def log_net_choice(self, net):
self.root.logger.info(f"Network architecture set to {net.upper()}")
@@ -104,34 +183,134 @@ def log_net_choice(self, net):
def log_augmentation_choice(self, augmentation):
self.root.logger.info(f"Image augmentation set to {augmentation.upper()}")
+ def edit_conversion_table(self):
+ # Test beforehand whether a conversion table exists
+ memory_replay_folder = Path(self.root.project_folder) / "memory_replay"
+ conversion_matrix_out_path = str(memory_replay_folder / "confusion_matrix.png")
+ files = [self.root.config]
+ if os.path.exists(conversion_matrix_out_path):
+ files.append(conversion_matrix_out_path)
+ _ = launch_napari(files)
+
def create_training_dataset(self):
shuffle = self.shuffle.value()
+ cfg = self.root.cfg
+ existing_indices = get_existing_shuffle_indices(
+ cfg=cfg, train_fraction=cfg["TrainingFraction"][self.root.trainingset_index]
+ )
+
+ overwrite = self.overwrite.isChecked()
+ if shuffle in existing_indices:
+ if overwrite:
+ if not self._confirm_overwrite(shuffle, existing_indices):
+ return
+ else:
+ msg = _create_message_box(
+ f"The training dataset could not be created.",
+ (
+ f"Shuffle {shuffle} already exists - you can create a new "
+ "training dataset with an unused shuffle index (existing "
+ f"shuffles are {existing_indices}) or you can overwrite the "
+ f"shuffle by ticking the 'Overwrite' checkbox"
+ ),
+ )
+ msg.exec_()
+ self.root.writer.write("Training dataset creation failed.")
+ return
if self.model_comparison:
raise NotImplementedError
# TODO: finish model_comparison
- deeplabcut.create_training_model_comparison(
- config_file,
- num_shuffles=shuffle,
- net_types=self.net_type,
- augmenter_types=self.aug_type,
- )
+ # deeplabcut.create_training_model_comparison(
+ # config_file,
+ # num_shuffles=shuffle,
+ # net_types=self.net_type,
+ # augmenter_types=self.aug_type,
+ # )
else:
- if self.root.is_multianimal:
- deeplabcut.create_multianimaltraining_dataset(
- self.root.config,
- shuffle,
- Shuffles=[self.shuffle.value()],
- net_type=self.net_choice.currentText(),
+ try:
+ engine = self.root.engine
+ net_type = self.net_choice.currentText()
+ detector_type = None
+ if engine == Engine.TF:
+ import tensorflow
+
+ # try importing TF so they can't create shuffles for it if they
+ # don't have it installed
+ elif engine == Engine.PYTORCH and "top_down" in net_type:
+ detector_type = self.detector_choice.currentText()
+
+ try:
+ weight_init = (
+ self.weight_init_selector.get_super_animal_weight_init(
+ net_type,
+ detector_type,
+ )
+ )
+ except ValueError as err:
+ print(f"The training dataset could not be created: {err}.")
+ return
+
+ if self.data_split_selection.selected:
+ deeplabcut.create_training_dataset_from_existing_split(
+ self.root.config,
+ from_shuffle=self.data_split_selection.from_shuffle,
+ shuffles=[self.shuffle.value()],
+ net_type=net_type,
+ detector_type=detector_type,
+ userfeedback=not overwrite,
+ weight_init=weight_init,
+ engine=engine,
+ )
+
+ elif self.root.is_multianimal:
+ deeplabcut.create_multianimaltraining_dataset(
+ self.root.config,
+ shuffle,
+ Shuffles=[self.shuffle.value()],
+ net_type=net_type,
+ detector_type=detector_type,
+ userfeedback=not overwrite,
+ weight_init=weight_init,
+ engine=engine,
+ )
+ else:
+ deeplabcut.create_training_dataset(
+ self.root.config,
+ shuffle,
+ Shuffles=[self.shuffle.value()],
+ net_type=net_type,
+ detector_type=detector_type,
+ augmenter_type=self.aug_choice.currentText(),
+ userfeedback=not overwrite,
+ weight_init=weight_init,
+ engine=engine,
+ )
+ except ValueError as err:
+ msg = _create_message_box(
+ f"The training dataset could not be created.",
+ str(err),
)
- else:
- deeplabcut.create_training_dataset(
- self.root.config,
- shuffle,
- Shuffles=[self.shuffle.value()],
- net_type=self.net_choice.currentText(),
- augmenter_type=self.aug_choice.currentText(),
+ msg.exec_()
+ return
+ except ModuleNotFoundError as err:
+ info_text = (
+ f"Error `{err}`. If the error is `ModuleNotFoundError: No module "
+ "named 'tensorflow'`, this is because you tried creating a "
+ "TensorFlow shuffle, but TensorFlow is not installed in your "
+ "environment. To create TensorFlow shuffles (and use TensorFlow "
+ "models), install it with\n"
+ " Windows/Linux:\n"
+ " pip install 'deeplabcut[tf]'\n"
+ " Apple Silicon:\n"
+ " pip install 'deeplabcut[apple_mchips]'"
)
+ msg = _create_message_box(
+ f"The training dataset could not be created.", info_text
+ )
+ msg.exec_()
+ return
+
# Check that training data files were indeed created.
trainingsetfolder = get_training_set_folder(self.root.cfg)
filenames = list(
@@ -148,6 +327,7 @@ def create_training_dataset(self):
os.path.exists(os.path.join(self.root.project_folder, file))
for file in filenames
):
+ self.root.shuffle_created.emit(self.shuffle.value())
msg = _create_message_box(
"The training dataset is successfully created.",
"Use the function 'train_network' to start training. Happy training!",
@@ -162,6 +342,369 @@ def create_training_dataset(self):
msg.exec_()
self.root.writer.write("Training dataset creation failed.")
+ def _confirm_overwrite(self, shuffle: int, existing_indices: list[int]) -> bool:
+ """
+ Asks the user to confirm that they want to overwrite a shuffle.
+
+ Args:
+ shuffle: the shuffle the user wants to overwrite
+ existing_indices: the indices of existing shuffles
+
+ Returns:
+ whether the user confirmed overwriting the shuffle
+ """
+ try:
+ engine = get_shuffle_engine(
+ self.root.cfg, self.root.trainingset_index, shuffle
+ )
+ engine_str = f" (with engine '{engine.aliases[0]}')"
+ except ValueError:
+ engine_str = ""
+
+ conf = _create_confirmation_box(
+ title=f"Are you sure you want to overwrite shuffle {shuffle}?",
+ description=(
+ f"As shuffle {shuffle} already exists{engine_str}, "
+ f"the training-dataset files would be overwritten."
+ ),
+ )
+ result = conf.exec()
+ if result != QtWidgets.QMessageBox.Yes:
+ msg = _create_message_box(
+ text="The training dataset was not be created.",
+ info_text=(
+ "You can create a shuffle with another index. Existing indices "
+ f"are {existing_indices}"
+ ),
+ )
+ msg.exec_()
+ self.root.writer.write("Training dataset creation interrupted.")
+ return False
+
+ return True
+
+ @Slot(Engine)
+ def update_nets(self, engine: Engine | None) -> None:
+ if engine is None:
+ engine = self.root.engine
+
+ default_net = None
+ if engine == Engine.TF:
+ nets = DLCParams.NNETS.copy()
+ if not self.root.is_multianimal:
+ nets.remove("dlcrnet_ms5")
+ else:
+ # FIXME: Circular imports make it impossible to import this at the top
+ from deeplabcut.pose_estimation_pytorch import available_models
+
+ nets = available_models()
+ net_filter = self.get_net_filter()
+ default_net = self.get_default_net()
+ td_prefix = "top_down_"
+ if net_filter is not None:
+ nets = [
+ n
+ for n in nets
+ if (
+ n in net_filter
+ or (
+ n.startswith(td_prefix)
+ and n[len(td_prefix) :] in net_filter
+ )
+ )
+ ]
+
+ while self.net_choice.count() > 0:
+ self.net_choice.removeItem(0)
+
+ self.net_choice.addItems(nets)
+ if default_net is None:
+ default_net = self.root.cfg.get("default_net_type", "resnet_50")
+
+ if default_net in nets:
+ self.net_choice.setCurrentIndex(nets.index(default_net))
+
+ @Slot(Engine)
+ def update_detectors(
+ self,
+ engine: Engine | None = None,
+ net_choice: str | None = None,
+ ) -> None:
+ if engine is None:
+ engine = self.root.engine
+
+ if engine == Engine.TF:
+ detectors = []
+ else:
+ # FIXME: Circular imports make it impossible to import this at the top
+ from deeplabcut.pose_estimation_pytorch import available_detectors
+
+ detectors = available_detectors()
+ det_filter = self.get_detector_filter()
+ if det_filter is not None:
+ detectors = [d for d in detectors if d in det_filter]
+
+ while self.detector_choice.count() > 0:
+ self.detector_choice.removeItem(0)
+
+ self.detector_choice.addItems(detectors)
+ default_detector = self.get_default_detector()
+ if default_detector in detectors:
+ self.detector_choice.setCurrentIndex(detectors.index(default_detector))
+ elif "ssdlite" in detectors:
+ self.detector_choice.setCurrentIndex(detectors.index("ssdlite"))
+
+ if net_choice is None:
+ net_choice = self.net_choice.currentText()
+
+ if "top_down" in net_choice:
+ self.detector_label.show()
+ self.detector_choice.show()
+ else:
+ self.detector_label.hide()
+ self.detector_choice.hide()
+
+ @Slot(Engine)
+ def update_aug_methods(self, engine: Engine) -> None:
+ methods = compat.get_available_aug_methods(engine)
+ while self.aug_choice.count() > 0:
+ self.aug_choice.removeItem(0)
+
+ self.aug_choice.addItems(methods)
+ self.aug_choice.setCurrentText(methods[0])
+
+ @Slot(Engine)
+ def update_weight_init_methods(self, engine: Engine) -> None:
+ if engine != Engine.PYTORCH:
+ self.weight_init_label.hide()
+ self.weight_init_selector.hide()
+ return
+
+ self.weight_init_label.show()
+ self.weight_init_selector.update_choices(list(_WEIGHT_INIT_OPTIONS.keys()))
+ self.weight_init_selector.show()
+
+ def get_net_filter(self) -> list[str] | None:
+ """Returns: the net type that can be used based on weight initialization"""
+ if self.root.engine != Engine.PYTORCH:
+ return None
+
+ if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS:
+ return None
+
+ weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init]
+ if "super_animal" in weight_init_cfg:
+ return dlclibrary.get_available_models(weight_init_cfg["super_animal"])
+
+ return None
+
+ def get_detector_filter(self) -> list[str] | None:
+ """Returns: the detectors that can be used based on weight initialization"""
+ if self.root.engine != Engine.PYTORCH:
+ return None
+
+ if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS:
+ return None
+
+ weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init]
+ if "super_animal" in weight_init_cfg:
+ return dlclibrary.get_available_detectors(weight_init_cfg["super_animal"])
+
+ return None
+
+ def get_default_net(self) -> str | None:
+ """Returns: the net type that can be used based on weight initialization"""
+ if self.root.engine != Engine.PYTORCH:
+ return None
+
+ if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS:
+ return None
+
+ weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init]
+ return weight_init_cfg.get("default_net")
+
+ def get_default_detector(self) -> str | None:
+ """Returns: the detector type that can be used based on weight initialization"""
+ if self.root.engine != Engine.PYTORCH:
+ return None
+
+ if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS:
+ return None
+
+ weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init]
+ return weight_init_cfg.get("default_detector")
+
+ def view_shuffles(self) -> None:
+ viewer = ShuffleMetadataViewer(root=self.root, parent=self)
+ viewer.show()
+
+
+class WeightInitializationSelector(QtWidgets.QWidget):
+ """Widget to select weight initialization"""
+
+ def __init__(self, root):
+ super().__init__()
+ self.root = root
+
+ self.weight_init_choice = QtWidgets.QComboBox()
+
+ self.memory_replay_label = QtWidgets.QLabel("With memory replay")
+ self.memory_replay_box = QtWidgets.QCheckBox()
+ self.memory_replay_label.hide()
+ self.memory_replay_box.hide()
+
+ memory_replay_layout = QtWidgets.QHBoxLayout()
+ memory_replay_layout.addWidget(self.memory_replay_label)
+ memory_replay_layout.addWidget(self.memory_replay_box)
+
+ layout = QtWidgets.QHBoxLayout()
+ layout.addWidget(self.weight_init_choice)
+ layout.addLayout(memory_replay_layout)
+ self.setLayout(layout)
+
+ self.weight_init_choice.currentTextChanged.connect(self._choice_changed)
+
+ @property
+ def weight_init(self) -> str:
+ return self.weight_init_choice.currentText()
+
+ @property
+ def with_decoder(self) -> bool:
+ weight_init_choice = self.weight_init_choice.currentText()
+ return "fine-tuning" in weight_init_choice.lower()
+
+ @property
+ def memory_replay(self) -> bool:
+ return self.memory_replay_box.isChecked()
+
+ def update_choices(self, choices: list[str]) -> None:
+ """Updates the WeightInitialization methods that can be selected"""
+ while self.weight_init_choice.count() > 0:
+ self.weight_init_choice.removeItem(0)
+ self.weight_init_choice.addItems(choices)
+
+ def get_super_animal_weight_init(
+ self,
+ net_type: str,
+ detector_type: str,
+ ) -> WeightInitialization | None:
+ """
+ Args:
+ net_type: The architecture of the pose model from which to fine-tune a
+ SuperAnimal model.
+ detector_type: The architecture of the detector from which to fine-tune a
+ SuperAnimal model.
+
+ Raises:
+ ValueError if WeightInitialization should be defined but could not be
+ created (e.g. if there's no conversion table).
+ """
+ if self.root.engine != Engine.PYTORCH:
+ return None
+
+ weight_init_choice = self.weight_init_choice.currentText()
+ if "imagenet" in weight_init_choice.lower():
+ return
+
+ weight_init_data = _WEIGHT_INIT_OPTIONS[weight_init_choice]
+ super_animal = weight_init_data["super_animal"]
+ if net_type.startswith("top_down_"):
+ net_type = net_type[len("top_down_") :]
+ try:
+ weight_init = build_weight_init(
+ self.root.cfg,
+ super_animal=super_animal,
+ model_name=net_type,
+ detector_name=detector_type,
+ with_decoder=self.with_decoder,
+ memory_replay=self.memory_replay,
+ )
+ except ValueError as err:
+ QtWidgets.QMessageBox.critical(
+ self,
+ "Error",
+ (
+ f"No Conversion table specified for {super_animal} in the project "
+ "configuration file. Please create a conversion table using the GUI"
+ ", with ``deeplabcut.modelzoo.utils.create_conversion_table``, or "
+ "by adding it to your project's configuration file manually."
+ ),
+ )
+ raise err
+
+ return weight_init
+
+ def _choice_changed(self, state: str) -> None:
+ if "fine-tuning" in str(state).lower():
+ self.memory_replay_label.show()
+ self.memory_replay_box.show()
+ else:
+ self.memory_replay_label.hide()
+ self.memory_replay_box.hide()
+
+
+class DataSplitSelector(QtWidgets.QWidget):
+ """Allows users to create training sets with the same train/test split as another"""
+
+ def __init__(self, root: QtWidgets.QMainWindow, parent: QtWidgets.QWidget):
+ super().__init__()
+ self.root = root
+ self.parent = parent
+
+ self.setToolTip(
+ "This allows you to create a shuffle where the data split is the same as "
+ "one of your existing shuffles (the images on which the model is "
+ "trained/tested are the same)."
+ )
+
+ layout = QtWidgets.QVBoxLayout()
+ layout.setSpacing(0)
+ layout.setContentsMargins(0, 0, 0, 0)
+
+ box_layout = QtWidgets.QHBoxLayout()
+ box_layout.setSpacing(0)
+ box_layout.setContentsMargins(0, 0, 0, 0)
+
+ selector_layout = QtWidgets.QHBoxLayout()
+ selector_layout.setSpacing(0)
+ selector_layout.setContentsMargins(0, 0, 0, 0)
+
+ self.shuffle_label = QtWidgets.QLabel("From shuffle:")
+ self.shuffle_label.hide()
+ self.shuffle_selector = QtWidgets.QSpinBox()
+ self.shuffle_selector.setMaximum(10_000)
+ self.shuffle_selector.setValue(0)
+ self.shuffle_selector.hide()
+
+ self.box = QtWidgets.QCheckBox(parent=self)
+ self.box.stateChanged.connect(self._checkbox_status_changed)
+ self.box_label = QtWidgets.QLabel("Use an existing data split")
+
+ box_layout.addWidget(self.box)
+ box_layout.addWidget(self.box_label)
+ selector_layout.addWidget(self.shuffle_label)
+ selector_layout.addWidget(self.shuffle_selector)
+ layout.addLayout(box_layout)
+ layout.addLayout(selector_layout)
+ self.setLayout(layout)
+
+ @property
+ def selected(self) -> bool:
+ return self.box.isChecked()
+
+ @property
+ def from_shuffle(self) -> int:
+ """The shuffle from which to copy the data split"""
+ return self.shuffle_selector.value()
+
+ def _checkbox_status_changed(self, state: int) -> None:
+ if Qt.CheckState(state) == Qt.Checked:
+ self.shuffle_selector.show()
+ self.shuffle_label.show()
+ else:
+ self.shuffle_selector.hide()
+ self.shuffle_label.hide()
+
def _create_message_box(text, info_text):
msg = QtWidgets.QMessageBox()
@@ -176,3 +719,56 @@ def _create_message_box(text, info_text):
msg.setWindowIcon(QIcon(logo))
msg.setStandardButtons(QtWidgets.QMessageBox.Ok)
return msg
+
+
+def _create_confirmation_box(title, description):
+ msg = QtWidgets.QMessageBox()
+ msg.setIcon(QtWidgets.QMessageBox.Information)
+ msg.setText(title)
+ msg.setInformativeText(description)
+
+ msg.setWindowTitle("Confirmation")
+ msg.setMinimumWidth(900)
+ logo_dir = os.path.dirname(os.path.realpath("logo.png")) + os.path.sep
+ logo = logo_dir + "/assets/logo.png"
+ msg.setWindowIcon(QIcon(logo))
+ msg.setStandardButtons(QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No)
+ return msg
+
+
+_WEIGHT_INIT_OPTIONS = { # FIXME - Generate dynamically
+ "Transfer Learning - ImageNet": {
+ "model_filter": None,
+ "detector_filter": None,
+ },
+ "Transfer Learning - SuperAnimal Bird": {
+ "default_net": "top_down_resnet_50",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_bird",
+ },
+ "Transfer Learning - SuperAnimal Quadruped": {
+ "default_net": "top_down_hrnet_w32",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_quadruped",
+ },
+ "Transfer Learning - SuperAnimal TopViewMouse": {
+ "default_net": "top_down_hrnet_w32",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_topviewmouse",
+ },
+ "Fine-tuning - SuperAnimal Bird": {
+ "default_net": "top_down_resnet_50",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_bird",
+ },
+ "Fine-tuning - SuperAnimal Quadruped": {
+ "default_net": "top_down_hrnet_w32",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_quadruped",
+ },
+ "Fine-tuning - SuperAnimal TopViewMouse": {
+ "default_net": "top_down_hrnet_w32",
+ "default_detector": "fasterrcnn_mobilenet_v3_large_fpn",
+ "super_animal": "superanimal_topviewmouse",
+ },
+}
diff --git a/deeplabcut/gui/tabs/create_videos.py b/deeplabcut/gui/tabs/create_videos.py
index d8299b9445..9cdf92789e 100644
--- a/deeplabcut/gui/tabs/create_videos.py
+++ b/deeplabcut/gui/tabs/create_videos.py
@@ -134,7 +134,7 @@ def _generate_layout_video_parameters(self, layout):
# Skeleton
self.draw_skeleton_checkbox = QtWidgets.QCheckBox("Draw skeleton")
- self.draw_skeleton_checkbox.setCheckState(Qt.Checked)
+ self.draw_skeleton_checkbox.setCheckState(Qt.Unchecked)
self.draw_skeleton_checkbox.stateChanged.connect(self.update_draw_skeleton)
tmp_layout.addWidget(self.draw_skeleton_checkbox)
@@ -146,6 +146,24 @@ def _generate_layout_video_parameters(self, layout):
)
tmp_layout.addWidget(self.use_filtered_data_checkbox)
+ # Selector for p-cutoff
+ pcutoff_widget = QtWidgets.QWidget()
+ pcutoff_layout = _create_horizontal_layout(margins=(0, 0, 0, 0))
+ pcutoff_label = QtWidgets.QLabel("Plotting confidence cutoff (pcutoff)")
+ self.pcutoff_selector = QtWidgets.QDoubleSpinBox()
+ self.pcutoff_selector.setMinimum(0.0)
+ self.pcutoff_selector.setMaximum(1.0)
+ self.pcutoff_selector.setValue(0.6)
+ self.pcutoff_selector.setSingleStep(0.05)
+ pcutoff_layout.addWidget(pcutoff_label)
+ pcutoff_layout.addWidget(self.pcutoff_selector)
+ pcutoff_widget.setLayout(pcutoff_layout)
+ pcutoff_widget.setToolTip(
+ "This value sets the confidence threshold, above which predictions are "
+ "shown in the labeled videos."
+ )
+ tmp_layout.addWidget(pcutoff_widget)
+
# Plot trajectories
self.plot_trajectories = QtWidgets.QCheckBox("Plot trajectories")
self.plot_trajectories.setCheckState(Qt.Unchecked)
@@ -179,11 +197,11 @@ def _generate_layout_video_parameters(self, layout):
layout.addLayout(tmp_layout, Qt.AlignLeft)
def update_high_quality_video(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"High quality {s}.")
def update_plot_trajectory_choice(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Plot trajectories {s}.")
def update_selected_bodyparts(self):
@@ -196,7 +214,7 @@ def update_selected_bodyparts(self):
self.bodyparts_to_use = selected_bodyparts
def update_use_all_bodyparts(self, s):
- if s == Qt.Checked:
+ if Qt.CheckState(s) == Qt.Checked:
self.bodyparts_list_widget.setEnabled(False)
self.bodyparts_list_widget.hide()
self.root.logger.info("Plot all bodyparts ENABLED.")
@@ -207,15 +225,15 @@ def update_use_all_bodyparts(self, s):
self.root.logger.info("Plot all bodyparts DISABLED.")
def update_use_filtered_data(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Use filtered data {s}")
def update_draw_skeleton(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Draw skeleton {s}")
def update_overwrite_videos(self, state):
- s = "ENABLED" if state == Qt.Checked else "DISABLED"
+ s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED"
self.root.logger.info(f"Overwrite videos {s}")
def update_color_by(self, text):
@@ -259,10 +277,12 @@ def create_videos(self):
shuffle=shuffle,
filtered=filtered,
save_frames=self.create_high_quality_video.isChecked(),
+ pcutoff=self.pcutoff_selector.value(),
displayedbodyparts=bodyparts,
draw_skeleton=self.draw_skeleton_checkbox.isChecked(),
trailpoints=trailpoints,
color_by=color_by,
+ overwrite=self.overwrite_videos.isChecked(),
)
if all(videos_created):
self.root.writer.write("Labeled videos created.")
@@ -280,6 +300,7 @@ def create_videos(self):
shuffle=shuffle,
filtered=filtered,
displayedbodyparts=bodyparts,
+ pcutoff=self.pcutoff_selector.value(),
)
def build_skeleton(self, *args):
diff --git a/deeplabcut/gui/tabs/docs.py b/deeplabcut/gui/tabs/docs.py
new file mode 100644
index 0000000000..1f52118e19
--- /dev/null
+++ b/deeplabcut/gui/tabs/docs.py
@@ -0,0 +1,15 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+BASE_URL = "https://deeplabcut.github.io/DeepLabCut/docs/"
+README = "https://deeplabcut.github.io/DeepLabCut/README.html"
+URL_3D = BASE_URL + "Overviewof3D.html"
+URL_MA_CONFIGURE = BASE_URL + "maDLC_UserGuide.html#configure-the-project"
+URL_USE_GUIDE_SCENARIO = BASE_URL + "UseOverviewGuide.html#what-scenario-do-you-have"
diff --git a/deeplabcut/gui/tabs/evaluate_network.py b/deeplabcut/gui/tabs/evaluate_network.py
index 36389224db..9b46ee09ce 100644
--- a/deeplabcut/gui/tabs/evaluate_network.py
+++ b/deeplabcut/gui/tabs/evaluate_network.py
@@ -8,17 +8,21 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+from __future__ import annotations
+
import os
import matplotlib.image as mpimg
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas,
)
from matplotlib.figure import Figure
+from pathlib import Path
from PySide6 import QtWidgets
-from PySide6.QtCore import Qt
+from PySide6.QtCore import Qt, Slot
import deeplabcut
-from deeplabcut.utils.auxiliaryfunctions import get_evaluation_folder
+from deeplabcut.core.engine import Engine
+from deeplabcut.gui.displays.selected_shuffle_display import SelectedShuffleDisplay
from deeplabcut.gui.components import (
BodypartListWidget,
DefaultTab,
@@ -27,7 +31,8 @@
_create_label_widget,
_create_vertical_layout,
)
-from deeplabcut.gui.widgets import ConfigEditor
+from deeplabcut.gui.widgets import ConfigEditor, launch_napari
+from deeplabcut.utils import auxiliaryfunctions
class GridCanvas(QtWidgets.QDialog):
@@ -91,6 +96,9 @@ def _set_page(self):
self.help_button.clicked.connect(self.show_help_dialog)
self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft)
+ self.root.engine_change.connect(self._on_engine_change)
+ self._on_engine_change(self.root.engine)
+
def show_help_dialog(self):
dialog = QtWidgets.QDialog(self)
layout = QtWidgets.QVBoxLayout()
@@ -107,9 +115,11 @@ def show_help_dialog(self):
def _generate_layout_attributes(self, layout):
opt_text = QtWidgets.QLabel("Shuffle")
self.shuffle = ShuffleSpinBox(root=self.root, parent=self)
+ self.shuffle_display = SelectedShuffleDisplay(self.root, row_margin=0)
layout.addWidget(opt_text)
layout.addWidget(self.shuffle)
+ layout.addWidget(self.shuffle_display)
def open_inferencecfg_editor(self):
editor = ConfigEditor(self.root.inference_cfg_path)
@@ -124,7 +134,7 @@ def plot_maps(self):
dest_folder = os.path.join(
self.root.project_folder,
str(
- get_evaluation_folder(
+ auxiliaryfunctions.get_evaluation_folder(
self.root.cfg["TrainingFraction"][0], shuffle, self.root.cfg
)
),
@@ -159,19 +169,19 @@ def _generate_additional_attributes(self, layout):
layout.addWidget(self.bodyparts_list_widget, alignment=Qt.AlignLeft)
def update_map_choice(self, state):
- if state == Qt.Checked:
+ if Qt.CheckState(state) == Qt.Checked:
self.root.logger.info("Plot scoremaps ENABLED")
else:
self.root.logger.info("Plot predictions DISABLED")
def update_plot_predictions(self, s):
- if s == Qt.Checked:
+ if Qt.CheckState(s) == Qt.Checked:
self.root.logger.info("Plot predictions ENABLED")
else:
self.root.logger.info("Plot predictions DISABLED")
def update_bodypart_choice(self, s):
- if s == Qt.Checked:
+ if Qt.CheckState(s) == Qt.Checked:
self.bodyparts_list_widget.setEnabled(False)
self.bodyparts_list_widget.hide()
self.root.logger.info("Use all bodyparts")
@@ -184,8 +194,7 @@ def update_bodypart_choice(self, s):
def evaluate_network(self):
config = self.root.config
-
- Shuffles = [self.root.shuffle_value]
+ shuffle = self.root.shuffle_value
plotting = self.plot_predictions.isChecked()
bodyparts_to_use = "all"
@@ -197,8 +206,38 @@ def evaluate_network(self):
deeplabcut.evaluate_network(
config,
- Shuffles=Shuffles,
+ Shuffles=[shuffle],
plotting=plotting,
show_errors=True,
comparisonbodyparts=bodyparts_to_use,
)
+
+ if plotting:
+ project_cfg = self.root.cfg
+ eval_folder = auxiliaryfunctions.get_evaluation_folder(
+ trainFraction=project_cfg["TrainingFraction"][0],
+ shuffle=shuffle,
+ cfg=project_cfg,
+ )
+ scorer, _ = auxiliaryfunctions.get_scorer_name(
+ cfg=project_cfg,
+ shuffle=shuffle,
+ trainFraction=project_cfg["TrainingFraction"][0],
+ )
+
+ image_dir = (
+ Path(self.root.project_folder)
+ / eval_folder
+ / f"LabeledImages_{scorer}"
+ )
+ labeled_images = [str(p) for p in image_dir.rglob("*.png")]
+ if len(labeled_images) > 0:
+ _ = launch_napari(image_dir)
+
+ @Slot(Engine)
+ def _on_engine_change(self, engine: Engine) -> None:
+ if engine == Engine.PYTORCH:
+ self.opt_button.hide()
+ return
+
+ self.opt_button.show()
diff --git a/deeplabcut/gui/tabs/label_frames.py b/deeplabcut/gui/tabs/label_frames.py
index 52261ad3b1..9eb49b276c 100644
--- a/deeplabcut/gui/tabs/label_frames.py
+++ b/deeplabcut/gui/tabs/label_frames.py
@@ -19,6 +19,7 @@
from deeplabcut.generate_training_dataset import check_labels
from deeplabcut.gui.components import DefaultTab
from deeplabcut.gui.widgets import launch_napari
+from deeplabcut.utils.skeleton import SkeletonBuilder
def label_frames(
@@ -110,8 +111,11 @@ def _set_page(self):
self.label_frames_btn.clicked.connect(self.label_frames)
self.check_labels_btn = QtWidgets.QPushButton("Check Labels")
self.check_labels_btn.clicked.connect(self.check_labels)
+ self.build_skeleton_btn = QtWidgets.QPushButton("Build skeleton")
+ self.build_skeleton_btn.clicked.connect(self.build_skeleton)
self.main_layout.addWidget(self.label_frames_btn, alignment=Qt.AlignLeft)
self.main_layout.addWidget(self.check_labels_btn, alignment=Qt.AlignLeft)
+ self.main_layout.addWidget(self.build_skeleton_btn, alignment=Qt.AlignLeft)
def log_color_by_option(self, choice):
self.root.logger.info(f"Labeled images will by colored by {choice.upper()}")
@@ -136,3 +140,8 @@ def label_frames(self):
def check_labels(self):
check_labels(self.root.config, visualizeindividuals=self.root.is_multianimal)
+ labeled_images = (Path(self.root.config).parent / "labeled-data").rglob("*_labeled/*.png")
+ _ = launch_napari(labeled_images, plugin="napari", stack=True)
+
+ def build_skeleton(self, *args):
+ SkeletonBuilder(self.root.config)
\ No newline at end of file
diff --git a/deeplabcut/gui/tabs/modelzoo.py b/deeplabcut/gui/tabs/modelzoo.py
index 97ddc2fe07..13039517f2 100644
--- a/deeplabcut/gui/tabs/modelzoo.py
+++ b/deeplabcut/gui/tabs/modelzoo.py
@@ -9,21 +9,25 @@
# Licensed under GNU Lesser General Public License v3.0
#
import os
+import webbrowser
from functools import partial
-import deeplabcut
+import dlclibrary
from PySide6 import QtWidgets
-from PySide6.QtCore import Qt, Signal, QTimer, QRegularExpression
-from PySide6.QtGui import QPixmap, QRegularExpressionValidator
+from PySide6.QtCore import QRegularExpression, Qt, QTimer, Signal, Slot
+from PySide6.QtGui import QIcon, QPixmap, QRegularExpressionValidator
+
+import deeplabcut
+from deeplabcut.core.engine import Engine
+from deeplabcut.gui import BASE_DIR
from deeplabcut.gui.components import (
+ _create_grid_layout,
+ _create_label_widget,
DefaultTab,
VideoSelectionWidget,
- _create_label_widget,
- _create_grid_layout,
)
-from deeplabcut.gui import BASE_DIR
from deeplabcut.gui.utils import move_to_separate_thread
-from deeplabcut.modelzoo.utils import parse_available_supermodels
+from deeplabcut.gui.widgets import ClickableLabel
class RegExpValidator(QRegularExpressionValidator):
@@ -40,6 +44,11 @@ def __init__(self, root, parent, h1_description):
super().__init__(root, parent, h1_description)
self._val_pattern = QRegularExpression(r"(\d{3,5},\s*)+\d{3,5}")
self._set_page()
+ self.root.engine_change.connect(self._on_engine_change)
+ self.root.engine_change.connect(self._update_available_models)
+ self._update_pose_models(self.model_combo.currentText())
+ self._update_detectors(self.model_combo.currentText())
+ self._destfolder = None
@property
def files(self):
@@ -50,18 +59,85 @@ def _set_page(self):
self.video_selection_widget = VideoSelectionWidget(self.root, self)
self.main_layout.addWidget(self.video_selection_widget)
- model_settings_layout = _create_grid_layout(margins=(20, 0, 0, 0))
+ self._build_common_attributes()
+ self._build_tf_attributes()
+ self._build_torch_attributes()
+
+ self.run_button = QtWidgets.QPushButton("Run")
+ self.run_button.clicked.connect(self.run_video_adaptation)
+ self.main_layout.addWidget(self.run_button, alignment=Qt.AlignRight)
+
+ self.home_button = QtWidgets.QPushButton("Return to Welcome page")
+ self.home_button.clicked.connect(self.root._generate_welcome_page)
+ self.main_layout.addWidget(self.home_button, alignment=Qt.AlignLeft)
+ self.help_button = QtWidgets.QPushButton("Help")
+ self.help_button.clicked.connect(self.show_help_dialog)
+ self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft)
+ self.go_to_button = QtWidgets.QPushButton("Read Documentation")
+ # go to url https://deeplabcut.github.io/DeepLabCut/docs/ModelZoo.html#about-the-superanimal-models when button is clicked
+ self.go_to_button.clicked.connect(
+ lambda: webbrowser.open(
+ "https://deeplabcut.github.io/DeepLabCut/docs/ModelZoo.html#about-the-superanimal-models"
+ )
+ )
+ self.main_layout.addWidget(self.go_to_button, alignment=Qt.AlignLeft)
+ self._on_engine_change(self.root.engine)
+
+ def _build_common_attributes(self) -> None:
+ settings_layout = _create_grid_layout(margins=(20, 0, 0, 0))
section_title = _create_label_widget(
"Supermodel Settings", "font:bold", (0, 50, 0, 0)
)
model_combo_text = QtWidgets.QLabel("Supermodel name")
+ model_combo_text.setMinimumWidth(300)
self.model_combo = QtWidgets.QComboBox()
- supermodels = parse_available_supermodels()
- self.model_combo.addItems(supermodels.keys())
+ self.model_combo.setMinimumWidth(250)
+
+ net_type_text = QtWidgets.QLabel("Net Type")
+ net_type_text.setMinimumWidth(300)
+ self.net_type_selector = QtWidgets.QComboBox()
+
+ self.detector_type_text = QtWidgets.QLabel("Detector Type")
+ self.detector_type_text.setMinimumWidth(300)
+ self.detector_type_selector = QtWidgets.QComboBox()
+
+ loc_label = ClickableLabel("Folder to store results:", parent=self)
+ loc_label.signal.connect(self.select_folder)
+ self.loc_line = QtWidgets.QLineEdit(
+ "",
+ self,
+ )
+ self.loc_line.setReadOnly(True)
+ action = self.loc_line.addAction(
+ QIcon(os.path.join(BASE_DIR, "assets", "icons", "open2.png")),
+ QtWidgets.QLineEdit.TrailingPosition,
+ )
+ action.triggered.connect(self.select_folder)
+
+ settings_layout.addWidget(section_title, 0, 0)
+ settings_layout.addWidget(model_combo_text, 1, 0)
+ settings_layout.addWidget(self.model_combo, 1, 1)
+ settings_layout.addWidget(net_type_text, 2, 0)
+ settings_layout.addWidget(self.net_type_selector, 2, 1)
+ settings_layout.addWidget(self.detector_type_text, 3, 0)
+ settings_layout.addWidget(self.detector_type_selector, 3, 1)
+
+ settings_layout.addWidget(loc_label, 4, 0)
+ settings_layout.addWidget(self.loc_line, 4, 1)
+
+ self.settings_widget = QtWidgets.QWidget()
+ self.settings_widget.setLayout(settings_layout)
+ self.main_layout.addWidget(self.settings_widget)
+ self.model_combo.currentTextChanged.connect(self._update_pose_models)
+ self.model_combo.currentTextChanged.connect(self._update_detectors)
+
+ def _build_tf_attributes(self) -> None:
+ model_settings_layout = _create_grid_layout(margins=(20, 0, 0, 0))
scales_label = QtWidgets.QLabel("Scale list")
+ scales_label.setMinimumWidth(300)
self.scales_line = QtWidgets.QLineEdit("", parent=self)
self.scales_line.setPlaceholderText(
"Optionally input a list of integer sizes separated by commas..."
@@ -77,7 +153,8 @@ def _set_page(self):
).scaledToWidth(30)
)
tooltip_label.setToolTip(
- "Approximate animal sizes in pixels, for spatial pyramid search. If left blank, defaults to video height +/- 50 pixels",
+ "Approximate animal sizes in pixels, for spatial pyramid search. If left "
+ "blank, defaults to video height +/- 50 pixels"
)
self.adapt_checkbox = QtWidgets.QCheckBox("Use video adaptation")
@@ -95,6 +172,7 @@ def _set_page(self):
self.pseudo_threshold_spinbox.setMaximumWidth(300)
adapt_iter_label = QtWidgets.QLabel("Number of adaptation iterations")
+ adapt_iter_label.setMinimumWidth(300)
self.adapt_iter_spinbox = QtWidgets.QSpinBox()
self.adapt_iter_spinbox.setRange(100, 10000)
self.adapt_iter_spinbox.setValue(1000)
@@ -102,26 +180,72 @@ def _set_page(self):
self.adapt_iter_spinbox.setGroupSeparatorShown(True)
self.adapt_iter_spinbox.setMaximumWidth(300)
- model_settings_layout.addWidget(section_title, 0, 0)
- model_settings_layout.addWidget(model_combo_text, 1, 0)
- model_settings_layout.addWidget(self.model_combo, 1, 1)
- model_settings_layout.addWidget(scales_label, 2, 0)
- model_settings_layout.addWidget(self.scales_line, 2, 1)
- model_settings_layout.addWidget(tooltip_label, 2, 2)
- model_settings_layout.addWidget(self.adapt_checkbox, 3, 0)
- model_settings_layout.addWidget(pseudo_threshold_label, 4, 0)
- model_settings_layout.addWidget(self.pseudo_threshold_spinbox, 4, 1)
- model_settings_layout.addWidget(adapt_iter_label, 5, 0)
- model_settings_layout.addWidget(self.adapt_iter_spinbox, 5, 1)
- self.main_layout.addLayout(model_settings_layout)
+ model_settings_layout.addWidget(scales_label, 1, 0)
+ model_settings_layout.addWidget(self.scales_line, 1, 1)
+ model_settings_layout.addWidget(tooltip_label, 1, 2)
+ model_settings_layout.addWidget(self.adapt_checkbox, 2, 0)
+ model_settings_layout.addWidget(pseudo_threshold_label, 3, 0)
+ model_settings_layout.addWidget(self.pseudo_threshold_spinbox, 3, 1)
+ model_settings_layout.addWidget(adapt_iter_label, 4, 0)
+ model_settings_layout.addWidget(self.adapt_iter_spinbox, 4, 1)
+ self.tf_widget = QtWidgets.QWidget()
+ self.tf_widget.setLayout(model_settings_layout)
+ self.tf_widget.hide()
+ self.main_layout.addWidget(self.tf_widget)
- self.run_button = QtWidgets.QPushButton("Run")
- self.run_button.clicked.connect(self.run_video_adaptation)
- self.main_layout.addWidget(self.run_button, alignment=Qt.AlignRight)
+ def _build_torch_attributes(self) -> None:
+ torch_settings_layout = _create_grid_layout(margins=(20, 0, 0, 0))
- self.help_button = QtWidgets.QPushButton("Help")
- self.help_button.clicked.connect(self.show_help_dialog)
- self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft)
+ self.torch_adapt_checkbox = QtWidgets.QCheckBox("Use video adaptation")
+ self.torch_adapt_checkbox.setChecked(True)
+
+ pseudo_threshold_label = QtWidgets.QLabel("Pseudo-label confidence threshold")
+ pseudo_threshold_label.setMinimumWidth(300)
+ self.torch_pseudo_threshold_spinbox = QtWidgets.QDoubleSpinBox(
+ decimals=2,
+ minimum=0.01,
+ maximum=1.0,
+ singleStep=0.05,
+ value=0.1,
+ wrapping=True,
+ )
+ self.torch_pseudo_threshold_spinbox.setMaximumWidth(300)
+
+ adapt_epoch_label = QtWidgets.QLabel("Number of adaptation epochs")
+ adapt_epoch_label.setMinimumWidth(300)
+ self.torch_adapt_epoch_spinbox = QtWidgets.QSpinBox()
+ self.torch_adapt_epoch_spinbox.setRange(1, 50)
+ self.torch_adapt_epoch_spinbox.setValue(4)
+ self.torch_adapt_epoch_spinbox.setMaximumWidth(300)
+
+ adapt_det_epoch_label = QtWidgets.QLabel("Number of detector adaptation epochs")
+ adapt_det_epoch_label.setMinimumWidth(300)
+ self.torch_adapt_det_epoch_spinbox = QtWidgets.QSpinBox()
+ self.torch_adapt_det_epoch_spinbox.setRange(1, 50)
+ self.torch_adapt_det_epoch_spinbox.setValue(4)
+ self.torch_adapt_det_epoch_spinbox.setMaximumWidth(300)
+
+ torch_settings_layout.addWidget(self.torch_adapt_checkbox, 1, 0)
+ torch_settings_layout.addWidget(pseudo_threshold_label, 2, 0)
+ torch_settings_layout.addWidget(self.torch_pseudo_threshold_spinbox, 2, 1)
+ torch_settings_layout.addWidget(adapt_epoch_label, 3, 0)
+ torch_settings_layout.addWidget(self.torch_adapt_epoch_spinbox, 3, 1)
+ torch_settings_layout.addWidget(adapt_det_epoch_label, 4, 0)
+ torch_settings_layout.addWidget(self.torch_adapt_det_epoch_spinbox, 4, 1)
+ self.torch_widget = QtWidgets.QWidget()
+ self.torch_widget.setLayout(torch_settings_layout)
+ self.torch_widget.hide()
+ self.main_layout.addWidget(self.torch_widget)
+
+ def select_folder(self):
+ dirname = QtWidgets.QFileDialog.getExistingDirectory(
+ self, "Please select a folder", self.root.project_folder
+ )
+ if not dirname:
+ return
+
+ self._destfolder = dirname
+ self.loc_line.setText(dirname)
def show_help_dialog(self):
dialog = QtWidgets.QDialog(self)
@@ -158,31 +282,121 @@ def run_video_adaptation(self):
msg.exec_()
return
- scales = []
- scales_ = self.scales_line.text()
- if scales_:
- if (
- self.scales_line.validator().validate(scales_, 0)[0]
- == RegExpValidator.Acceptable
- ):
- scales = list(map(int, scales_.split(",")))
supermodel_name = self.model_combo.currentText()
videotype = self.video_selection_widget.videotype_widget.currentText()
+ kwargs = self._gather_kwargs()
- func = partial(
- deeplabcut.video_inference_superanimal,
- videos,
- supermodel_name,
- videotype=videotype,
- video_adapt=self.adapt_checkbox.isChecked(),
- scale_list=scales,
- pseudo_threshold=self.pseudo_threshold_spinbox.value(),
- adapt_iterations=self.adapt_iter_spinbox.value(),
- )
+ can_run_in_background = False
+ if can_run_in_background:
+ func = partial(
+ deeplabcut.video_inference_superanimal,
+ videos,
+ supermodel_name,
+ videotype=videotype,
+ dest_folder=self._destfolder,
+ **kwargs,
+ )
+
+ self.worker, self.thread = move_to_separate_thread(func)
+ self.worker.finished.connect(self.signal_analysis_complete)
+ self.thread.start()
+ self.run_button.setEnabled(False)
+ self.root._progress_bar.show()
+ else:
+ print(f"Calling video_inference_superanimal with kwargs={kwargs}")
+ deeplabcut.video_inference_superanimal(
+ videos,
+ supermodel_name,
+ videotype=videotype,
+ dest_folder=self._destfolder,
+ **kwargs,
+ )
+ self.signal_analysis_complete()
+
+ def signal_analysis_complete(self):
+ self.run_button.setEnabled(True)
+ self.root._progress_bar.hide()
+ msg = QtWidgets.QMessageBox(text="SuperAnimal video inference complete!")
+ msg.setIcon(QtWidgets.QMessageBox.Information)
+ msg.exec_()
+
+ def _gather_kwargs(self) -> dict:
+ kwargs = dict(model_name=self.net_type_selector.currentText())
+ if self.root.engine == Engine.TF:
+ scales = []
+ scales_ = self.scales_line.text()
+ if scales_:
+ if (
+ self.scales_line.validator().validate(scales_, 0)[0]
+ == RegExpValidator.Acceptable
+ ):
+ scales = list(map(int, scales_.split(",")))
+ kwargs["scale_list"] = scales
+ kwargs["video_adapt"] = self.adapt_checkbox.isChecked()
+ kwargs["pseudo_threshold"] = self.pseudo_threshold_spinbox.value()
+ kwargs["adapt_iterations"] = self.adapt_iter_spinbox.value()
+ else:
+ kwargs["detector_name"] = self.detector_type_selector.currentText()
+ kwargs["video_adapt"] = self.torch_adapt_checkbox.isChecked()
+ kwargs["pseudo_threshold"] = self.torch_pseudo_threshold_spinbox.value()
+ kwargs["detector_epochs"] = self.torch_adapt_det_epoch_spinbox.value()
+ kwargs["pose_epochs"] = self.torch_adapt_epoch_spinbox.value()
+
+ return kwargs
+
+ def _update_available_models(self, engine: Engine) -> None:
+ current_dataset = self.model_combo.currentText()
+
+ while self.model_combo.count() > 0:
+ self.model_combo.removeItem(0)
+
+ if engine == Engine.TF:
+ supermodels = ["superanimal_topviewmouse", "superanimal_quadruped"]
+ else:
+ supermodels = dlclibrary.get_available_datasets()
+
+ self.model_combo.addItems(supermodels)
+ if current_dataset in supermodels:
+ self.model_combo.setCurrentIndex(supermodels.index(current_dataset))
+
+ def _update_pose_models(self, super_animal: str) -> None:
+ while self.net_type_selector.count() > 0:
+ self.net_type_selector.removeItem(0)
+
+ if len(super_animal) == 0:
+ return
+
+ if self.root.engine == Engine.TF:
+ self.net_type_selector.addItems(["dlcrnet"])
+ else:
+ self.net_type_selector.addItems(
+ dlclibrary.get_available_models(super_animal)
+ )
+
+ def _update_detectors(self, super_animal: str) -> None:
+ while self.detector_type_selector.count() > 0:
+ self.detector_type_selector.removeItem(0)
+
+ if len(super_animal) == 0:
+ return
+
+ if self.root.engine == Engine.TF:
+ self.detector_type_selector.addItems(["dlcrnet"])
+ else:
+ self.detector_type_selector.addItems(
+ dlclibrary.get_available_detectors(super_animal)
+ )
- self.worker, self.thread = move_to_separate_thread(func)
- self.worker.finished.connect(lambda: self.run_button.setEnabled(True))
- self.worker.finished.connect(lambda: self.root._progress_bar.hide())
- self.thread.start()
- self.run_button.setEnabled(False)
- self.root._progress_bar.show()
+ @Slot(Engine)
+ def _on_engine_change(self, engine: Engine) -> None:
+ self._update_available_models(engine)
+ if engine == Engine.PYTORCH:
+ self.tf_widget.hide()
+ self.detector_type_text.show()
+ self.detector_type_selector.show()
+ self.torch_widget.show()
+ else:
+ self.torch_widget.hide()
+ self.detector_type_text.hide()
+ self.detector_type_selector.hide()
+ self.tf_widget.show()
diff --git a/deeplabcut/gui/tabs/refine_tracklets.py b/deeplabcut/gui/tabs/refine_tracklets.py
index e32d0bf906..74227b3512 100644
--- a/deeplabcut/gui/tabs/refine_tracklets.py
+++ b/deeplabcut/gui/tabs/refine_tracklets.py
@@ -24,7 +24,7 @@
)
import deeplabcut
-from deeplabcut.pose_estimation_tensorflow.lib import trackingutils
+from deeplabcut.core import trackingutils
from deeplabcut.utils.auxiliaryfunctions import GetScorerName
diff --git a/deeplabcut/gui/tabs/train_network.py b/deeplabcut/gui/tabs/train_network.py
index a4ffbcef52..212302eafd 100644
--- a/deeplabcut/gui/tabs/train_network.py
+++ b/deeplabcut/gui/tabs/train_network.py
@@ -8,47 +8,113 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+from __future__ import annotations
+
import os
-from pathlib import Path
+from dataclasses import dataclass
from PySide6 import QtWidgets
-from PySide6.QtCore import Qt
+from PySide6.QtCore import Qt, Slot
from PySide6.QtGui import QIcon
+import deeplabcut.compat as compat
+from deeplabcut.core.engine import Engine
from deeplabcut.gui.components import (
DefaultTab,
ShuffleSpinBox,
+ SnapshotSelectionWidget,
_create_grid_layout,
_create_label_widget,
)
+from deeplabcut.gui.displays.selected_shuffle_display import SelectedShuffleDisplay
from deeplabcut.gui.widgets import ConfigEditor
-import deeplabcut
-from deeplabcut.utils import auxiliaryfunctions
+
+@dataclass
+class IntTrainAttribute:
+ label: str
+ fn_key: str
+ default: int
+ min: int
+ max: int
+ tooltip: str | None = None
+
+
+@dataclass
+class TrainAttributeRow:
+ attributes: list[IntTrainAttribute]
+ description: str | None = None
+ show_when_cfg: tuple[str, str] | None = None
class TrainNetwork(DefaultTab):
def __init__(self, root, parent, h1_description):
super(TrainNetwork, self).__init__(root, parent, h1_description)
+ self._shuffle: ShuffleSpinBox = ShuffleSpinBox(root=self.root, parent=self)
+ self._shuffle_display = SelectedShuffleDisplay(self.root)
- # use the default pose_cfg file for default values
- default_pose_cfg_path = os.path.join(
- Path(deeplabcut.__file__).parent, "pose_cfg.yaml"
- )
- pose_cfg = auxiliaryfunctions.read_plainconfig(default_pose_cfg_path)
- self.display_iters = str(pose_cfg["display_iters"])
- self.save_iters = str(pose_cfg["save_iters"])
- self.max_iters = str(pose_cfg["multi_step"][-1][-1])
-
+ self._attribute_layouts: dict[Engine, QtWidgets.QWidget] = {}
+ self._attribute_kwargs: dict[Engine, dict] = {}
+ self._rows_with_requirements: list = []
self._set_page()
+ self.root.engine_change.connect(self._on_engine_change)
+ self._shuffle_display.pose_cfg_signal.connect(self._pose_cfg_change)
+
+ @Slot(Engine)
+ def _on_engine_change(self, engine: Engine) -> None:
+ for e, layout in self._attribute_layouts.items():
+ if e == engine:
+ layout.show()
+ else:
+ layout.hide()
+ self._update_snapshot_selection_widgets_visibility()
+
+ def _update_snapshot_selection_widgets_visibility(self):
+ if self.root.engine == Engine.PYTORCH:
+ self.resume_from_snapshot_label.show()
+ self.snapshot_selection_widget.show()
+ # Display detector snapshot selection widget only if in Top-Down mode
+ if self._shuffle_display.pose_cfg.get("method", "").lower() == "td":
+ self.detector_snapshot_selection_widget.show()
+ else:
+ self.detector_snapshot_selection_widget.hide()
+ else:
+ self.resume_from_snapshot_label.hide()
+ self.snapshot_selection_widget.hide()
+ self.detector_snapshot_selection_widget.hide()
+
def _set_page(self):
self.main_layout.addWidget(_create_label_widget("Attributes", "font:bold"))
- self.layout_attributes = _create_grid_layout(margins=(20, 0, 0, 0))
- self._generate_layout_attributes(self.layout_attributes)
- self.main_layout.addLayout(self.layout_attributes)
+ self._generate_layout_attributes()
+
+ self.resume_from_snapshot_label = _create_label_widget(
+ "[Optional]: Select a snapshot to resume training from", "font:bold"
+ )
+ self.resume_from_snapshot_label.setToolTip(
+ ""
+ "If you've already trained a model on this shuffle, you can continue training it instead of starting "
+ "from scratch again. When using top-down models, you can also choose a detector to resume training from."
+ " "
+ )
+ self.main_layout.addWidget(self.resume_from_snapshot_label)
+
+ self.snapshot_selection_widget = SnapshotSelectionWidget(
+ self.root, self, margins=(30, 0, 0, 0), select_button_text="Select snapshot"
+ )
+ self.main_layout.addWidget(self.snapshot_selection_widget)
+
+ self.detector_snapshot_selection_widget = SnapshotSelectionWidget(
+ self.root,
+ self,
+ margins=(30, 0, 0, 0),
+ select_button_text="Select detector snapshot",
+ )
+ self.main_layout.addWidget(self.detector_snapshot_selection_widget)
- self.main_layout.addWidget(_create_label_widget("")) # dummy label
+ self._pose_cfg_change(
+ self._shuffle_display.pose_cfg
+ ) # also calls _update_snapshot_selection_widgets_visibility
self.edit_posecfg_btn = QtWidgets.QPushButton("Edit pose_cfg.yaml")
self.edit_posecfg_btn.setMinimumWidth(150)
@@ -68,7 +134,7 @@ def _set_page(self):
def show_help_dialog(self):
dialog = QtWidgets.QDialog(self)
layout = QtWidgets.QVBoxLayout()
- label = QtWidgets.QLabel(deeplabcut.train_network.__doc__, self)
+ label = QtWidgets.QLabel(compat.train_network.__doc__, self)
scroll = QtWidgets.QScrollArea()
scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
@@ -78,66 +144,79 @@ def show_help_dialog(self):
dialog.setLayout(layout)
dialog.exec_()
- def _generate_layout_attributes(self, layout):
- # Shuffle
+ def _generate_layout_attributes(self) -> None:
+ row_margin = 25
+
+ # top layout
shuffle_label = QtWidgets.QLabel("Shuffle")
- self.shuffle = ShuffleSpinBox(root=self.root, parent=self)
-
- # Display iterations
- dispiters_label = QtWidgets.QLabel("Display iterations")
- self.display_iters_spin = QtWidgets.QSpinBox()
- self.display_iters_spin.setMinimum(1)
- self.display_iters_spin.setMaximum(int(self.max_iters))
- self.display_iters_spin.setValue(1000)
- self.display_iters_spin.valueChanged.connect(self.log_display_iters)
-
- # Save iterations
- saveiters_label = QtWidgets.QLabel("Save iterations")
- self.save_iters_spin = QtWidgets.QSpinBox()
- self.save_iters_spin.setMinimum(1)
- self.save_iters_spin.setMaximum(int(self.max_iters))
- self.save_iters_spin.setValue(50000)
- self.save_iters_spin.valueChanged.connect(self.log_save_iters)
-
- # Max iterations
- maxiters_label = QtWidgets.QLabel("Maximum iterations")
- self.max_iters_spin = QtWidgets.QSpinBox()
- self.max_iters_spin.setMinimum(1)
- self.max_iters_spin.setMaximum(int(self.max_iters))
- self.max_iters_spin.setValue(100000)
- self.max_iters_spin.valueChanged.connect(self.log_max_iters)
-
- # Max number snapshots to keep
- snapkeep_label = QtWidgets.QLabel("Number of snapshots to keep")
- self.snapshots = QtWidgets.QSpinBox()
- self.snapshots.setMinimum(1)
- self.snapshots.setMaximum(100)
- self.snapshots.setValue(5)
- self.snapshots.valueChanged.connect(self.log_snapshots)
-
- layout.addWidget(shuffle_label, 0, 0)
- layout.addWidget(self.shuffle, 0, 1)
- layout.addWidget(dispiters_label, 0, 2)
- layout.addWidget(self.display_iters_spin, 0, 3)
- layout.addWidget(saveiters_label, 0, 4)
- layout.addWidget(self.save_iters_spin, 0, 5)
- layout.addWidget(maxiters_label, 0, 6)
- layout.addWidget(self.max_iters_spin, 0, 7)
- layout.addWidget(snapkeep_label, 0, 8)
- layout.addWidget(self.snapshots, 0, 9)
- # layout.addWidget()
-
- def log_display_iters(self, value):
- self.root.logger.info(f"Display iters set to {value}")
-
- def log_save_iters(self, value):
- self.root.logger.info(f"Save iters set to {value}")
-
- def log_max_iters(self, value):
- self.root.logger.info(f"Max iters set to {value}")
-
- def log_snapshots(self, value):
- self.root.logger.info(f"Max snapshots to keep set to {value}")
+ shuffle_label.setStyleSheet(f"margin: 0px 0px {row_margin}px 0px")
+ self._shuffle.setStyleSheet(f"margin: 0px 0px {row_margin}px 0px")
+ self._shuffle_display.setStyleSheet(f"margin: 0px 0px {row_margin}px 0px")
+
+ base_layout = _create_grid_layout(margins=(20, 0, 0, 0))
+ base_layout.addWidget(shuffle_label, 0, 0)
+ base_layout.addWidget(self._shuffle, 0, 1)
+ base_layout.addWidget(self._shuffle_display, 0, 2)
+ base_layout_widget = QtWidgets.QWidget()
+ base_layout_widget.setLayout(base_layout)
+ self.main_layout.addWidget(base_layout_widget)
+
+ for engine in Engine:
+ train_attributes = get_train_attributes(engine)
+
+ # Other parameters
+ param_layout = _create_grid_layout(margins=(20, 0, 0, 0))
+ param_layout.setVerticalSpacing(0)
+
+ self._attribute_kwargs[engine] = {}
+ row_index = 1
+ for row in train_attributes:
+ row_elements = []
+ if row.description is not None:
+ row_label = QtWidgets.QLabel(row.description)
+ row_label.setStyleSheet("font-weight: bold")
+ row_elements.append(row_label)
+ param_layout.addWidget(row_label, row_index, 0)
+ row_index += 1
+
+ for j, attribute in enumerate(row.attributes):
+ label = QtWidgets.QLabel(attribute.label)
+ spin_box = QtWidgets.QSpinBox()
+ spin_box.setMinimum(attribute.min)
+ spin_box.setMaximum(attribute.max)
+ spin_box.setValue(attribute.default)
+ spin_box.valueChanged.connect(
+ lambda new_val: self.log_attribute_change(attribute, new_val)
+ )
+ self._attribute_kwargs[engine][attribute.fn_key] = spin_box
+
+ # Pad below to create spacing with other rows
+ label.setStyleSheet(f"margin: 0px 0px {row_margin}px 0px")
+ spin_box.setStyleSheet(f"margin: 0px 0px {row_margin}px 0px")
+
+ row_elements.append(label)
+ row_elements.append(spin_box)
+
+ param_layout.addWidget(label, row_index, 2 * j)
+ param_layout.addWidget(spin_box, row_index, 2 * j + 1)
+
+ if row.show_when_cfg is not None:
+ self._rows_with_requirements.append(
+ (row.show_when_cfg, row_elements)
+ )
+
+ row_index += 1
+
+ layout_widget = QtWidgets.QWidget()
+ layout_widget.setLayout(param_layout)
+ self._attribute_layouts[engine] = layout_widget
+ if engine != self.root.engine:
+ layout_widget.hide()
+
+ self.main_layout.addWidget(layout_widget)
+
+ def log_attribute_change(self, attribute: IntTrainAttribute, value: int) -> None:
+ self.root.logger.info(f"{attribute.label} set to {value}")
def open_posecfg_editor(self):
editor = ConfigEditor(self.root.pose_cfg_path)
@@ -145,22 +224,24 @@ def open_posecfg_editor(self):
def train_network(self):
config = self.root.config
- shuffle = int(self.shuffle.value())
- max_snapshots_to_keep = int(self.snapshots.value())
- displayiters = int(self.display_iters_spin.value())
- saveiters = int(self.save_iters_spin.value())
- maxiters = int(self.max_iters_spin.value())
-
- deeplabcut.train_network(
- config,
- shuffle,
- gputouse=None,
- max_snapshots_to_keep=max_snapshots_to_keep,
- autotune=None,
- displayiters=displayiters,
- saveiters=saveiters,
- maxiters=maxiters,
- )
+ shuffle = int(self._shuffle.value())
+
+ kwargs = dict(gputouse=None, autotune=False)
+ for k, spin_box in self._attribute_kwargs[self.root.engine].items():
+ kwargs[k] = int(spin_box.value())
+ if self.root.engine == Engine.PYTORCH:
+ snapshot_to_start_training_from = (
+ self.snapshot_selection_widget.selected_snapshot
+ )
+ if snapshot_to_start_training_from is not None:
+ kwargs["snapshot_path"] = snapshot_to_start_training_from
+ detector_to_start_training_from = (
+ self.detector_snapshot_selection_widget.selected_snapshot
+ )
+ if detector_to_start_training_from is not None:
+ kwargs["detector_path"] = detector_to_start_training_from
+
+ compat.train_network(config, shuffle, **kwargs)
msg = QtWidgets.QMessageBox()
msg.setIcon(QtWidgets.QMessageBox.Information)
msg.setText("The network is now trained and ready to evaluate.")
@@ -175,3 +256,124 @@ def train_network(self):
msg.setWindowIcon(QIcon(self.logo))
msg.setStandardButtons(QtWidgets.QMessageBox.Ok)
msg.exec_()
+
+ @Slot(dict)
+ def _pose_cfg_change(self, pose_cfg: dict | None) -> None:
+ if pose_cfg is None:
+ return
+
+ for requirement, widgets in self._rows_with_requirements:
+ key, value = requirement
+ show = pose_cfg.get(key) == value
+ for w in widgets:
+ if show:
+ w.show()
+ else:
+ w.hide()
+
+ self._update_snapshot_selection_widgets_visibility()
+
+
+def get_train_attributes(engine: Engine) -> list[TrainAttributeRow]:
+ if engine == Engine.TF:
+ return [
+ TrainAttributeRow(
+ attributes=[
+ IntTrainAttribute(
+ label="Display iterations",
+ fn_key="displayiters",
+ default=1000,
+ min=1,
+ max=1000,
+ ),
+ IntTrainAttribute(
+ label="Number of snapshots to keep",
+ fn_key="max_snapshots_to_keep",
+ default=5,
+ min=1,
+ max=100,
+ ),
+ ],
+ ),
+ TrainAttributeRow(
+ attributes=[
+ IntTrainAttribute(
+ label="Maximum iterations",
+ fn_key="maxiters",
+ default=100_000,
+ min=1,
+ max=1_030_000,
+ ),
+ IntTrainAttribute(
+ label="Save iterations",
+ fn_key="saveiters",
+ default=50_000,
+ min=1,
+ max=50_000,
+ ),
+ ],
+ ),
+ ]
+ elif engine == Engine.PYTORCH:
+ return [
+ TrainAttributeRow(
+ attributes=[
+ IntTrainAttribute(
+ label="Display iterations",
+ fn_key="displayiters",
+ default=1_000,
+ min=1,
+ max=100_000,
+ ),
+ IntTrainAttribute(
+ label="Number of snapshots to keep",
+ fn_key="max_snapshots_to_keep",
+ default=5,
+ min=1,
+ max=100,
+ ),
+ ],
+ ),
+ TrainAttributeRow(
+ attributes=[
+ IntTrainAttribute(
+ label="Maximum epochs",
+ fn_key="epochs",
+ default=200,
+ min=1,
+ max=1000,
+ ),
+ IntTrainAttribute(
+ label="Save epochs",
+ fn_key="save_epochs",
+ default=50,
+ min=1,
+ max=250,
+ ),
+ ],
+ ),
+ TrainAttributeRow(
+ description="Detector parameters",
+ show_when_cfg=("method", "td"),
+ attributes=[
+ IntTrainAttribute(
+ label="Detector max epochs",
+ fn_key="detector_epochs",
+ default=200,
+ min=0,
+ max=1000,
+ tooltip="",
+ ),
+ IntTrainAttribute(
+ label="Detector save epochs",
+ fn_key="detector_save_epochs",
+ default=50,
+ min=1,
+ max=250,
+ tooltip="",
+ ),
+ ],
+ ),
+ ]
+
+ raise NotImplementedError(f"Unknown engine: {engine}")
diff --git a/deeplabcut/gui/tabs/video_editor.py b/deeplabcut/gui/tabs/video_editor.py
index 38be396f05..ecc0374c4d 100644
--- a/deeplabcut/gui/tabs/video_editor.py
+++ b/deeplabcut/gui/tabs/video_editor.py
@@ -48,6 +48,11 @@ def _set_page(self):
self.down_button.clicked.connect(self.downsample_videos)
self.main_layout.addWidget(self.down_button, alignment=Qt.AlignRight)
+ self.rotate_button = QtWidgets.QPushButton("Rotate")
+ self.rotate_button.setMinimumWidth(150)
+ self.rotate_button.clicked.connect(self.rotate_videos)
+ self.main_layout.addWidget(self.rotate_button, alignment=Qt.AlignRight)
+
self.trim_button = QtWidgets.QPushButton("Trim")
self.trim_button.setMinimumWidth(150)
self.trim_button.clicked.connect(self.trim_videos)
@@ -130,6 +135,21 @@ def log_video_stop(self, value):
def log_rotation_angle(self, value):
self.root.logger.info(f"Rotation angle set to {value}")
+ def rotate_videos(self):
+ if self.files:
+ for video in self.files:
+ if self.video_rotation.currentText() == "specific angle":
+ auxfun_videos.rotate_video(
+ video, self.rotation_angle.value(), "Arbitrary"
+ )
+ elif self.video_rotation.currentText() == "clockwise":
+ auxfun_videos.rotate_video(
+ video, 0, "Yes"
+ )
+ else:
+ self.root.logger.error("No videos selected...")
+
+
def trim_videos(self):
start = time.strftime("%H:%M:%S", time.gmtime(self.video_start.value()))
stop = time.strftime("%H:%M:%S", time.gmtime(self.video_stop.value()))
diff --git a/deeplabcut/gui/utils.py b/deeplabcut/gui/utils.py
index 5fb443045d..1de8de9c24 100644
--- a/deeplabcut/gui/utils.py
+++ b/deeplabcut/gui/utils.py
@@ -8,9 +8,10 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-from typing import Callable
+from typing import Callable, Tuple
from PySide6 import QtCore
+import re
class Worker(QtCore.QObject):
@@ -56,6 +57,17 @@ def stop_thread():
return worker, thread
+def parse_version(version: str) -> Tuple[int, int, int]:
+ """
+ Parses a version string into a tuple of (major, minor, patch).
+ """
+ match = re.search(r"(\d+)\.(\d+)\.(\d+)", version)
+ if match:
+ return tuple(int(part) for part in match.groups())
+ else:
+ raise ValueError(f"Invalid version format: {version}")
+
+
def is_latest_deeplabcut_version():
import json
import urllib.request
@@ -64,4 +76,4 @@ def is_latest_deeplabcut_version():
url = "https://pypi.org/pypi/deeplabcut/json"
contents = urllib.request.urlopen(url).read()
latest_version = json.loads(contents)["info"]["version"]
- return VERSION == latest_version, latest_version
+ return parse_version(VERSION) >= parse_version(latest_version), latest_version
diff --git a/deeplabcut/gui/widgets.py b/deeplabcut/gui/widgets.py
index 4e16ab83f7..af206baa6e 100644
--- a/deeplabcut/gui/widgets.py
+++ b/deeplabcut/gui/widgets.py
@@ -34,15 +34,16 @@
from deeplabcut.utils.auxfun_videos import VideoWriter
-def launch_napari(files=None):
+def launch_napari(files=None, plugin="napari-deeplabcut", stack=False):
viewer = napari.Viewer()
- # Automatically activate the napari-deeplabcut plugin
- for action in viewer.window.plugins_menu.actions():
- if "deeplabcut" in action.text():
- action.trigger()
- break
+ if plugin == "napari-deeplabcut":
+ # Automatically activate the napari-deeplabcut plugin
+ for action in viewer.window.plugins_menu.actions():
+ if "deeplabcut" in action.text():
+ action.trigger()
+ break
if files is not None:
- viewer.open(files, plugin="napari-deeplabcut")
+ viewer.open(files, plugin=plugin, stack=stack)
return viewer
@@ -451,7 +452,10 @@ class ConfigEditor(QtWidgets.QDialog):
def __init__(self, config, parent=None):
super(ConfigEditor, self).__init__(parent)
self.config = config
- if config.endswith("config.yaml"):
+ if (
+ config.endswith("config.yaml")
+ and not config.endswith("pytorch_config.yaml")
+ ):
self.read_func = auxiliaryfunctions.read_config
self.write_func = auxiliaryfunctions.write_config
else:
diff --git a/deeplabcut/gui/window.py b/deeplabcut/gui/window.py
index 2e81542523..3a8eac3529 100644
--- a/deeplabcut/gui/window.py
+++ b/deeplabcut/gui/window.py
@@ -15,24 +15,43 @@
from functools import cached_property
from pathlib import Path
from typing import List
+from urllib.error import URLError
import qdarkstyle
import deeplabcut
-from deeplabcut import auxiliaryfunctions, VERSION
+from deeplabcut import auxiliaryfunctions, VERSION, compat
+from deeplabcut.core.engine import Engine
from deeplabcut.gui import BASE_DIR, components, utils
from deeplabcut.gui.tabs import *
from deeplabcut.gui.widgets import StreamReceiver, StreamWriter
+from deeplabcut.utils.multiprocessing import call_with_timeout
from napari_deeplabcut import misc
-from PySide6.QtWidgets import QMessageBox, QMenu, QWidget, QMainWindow
+from PySide6.QtWidgets import (
+ QMessageBox,
+ QMenu,
+ QWidget,
+ QMainWindow,
+ QComboBox,
+ QLabel,
+ QSizePolicy,
+)
from PySide6 import QtCore
-from PySide6.QtGui import QIcon, QAction
+from PySide6.QtGui import QIcon, QAction, QPixmap
from PySide6 import QtWidgets, QtGui
from PySide6.QtCore import Qt, QTimer
def _check_for_updates(silent=True):
- is_latest, latest_version = utils.is_latest_deeplabcut_version()
- is_latest_plugin, latest_plugin_version = misc.is_latest_version()
+ try:
+ is_latest, latest_version = call_with_timeout(
+ utils.is_latest_deeplabcut_version, 5
+ )
+ is_latest_plugin, latest_plugin_version = call_with_timeout(
+ misc.is_latest_version, 5
+ )
+ except (URLError, TimeoutError): # Handle internet connectivity issues
+ is_latest = is_latest_plugin = True
+
if is_latest and is_latest_plugin:
if not silent:
msg = QtWidgets.QMessageBox(
@@ -54,9 +73,9 @@ def _check_for_updates(silent=True):
text=text,
)
msg.setIcon(QtWidgets.QMessageBox.Information)
- update_btn = msg.addButton("Update", msg.AcceptRole)
+ update_btn = msg.addButton("Update", QtWidgets.QMessageBox.AcceptRole)
msg.setDefaultButton(update_btn)
- _ = msg.addButton("Skip", msg.RejectRole)
+ _ = msg.addButton("Skip", QtWidgets.QMessageBox.RejectRole)
msg.exec_()
if msg.clickedButton() is update_btn:
subprocess.check_call([sys.executable, "-m", *command])
@@ -66,6 +85,9 @@ class MainWindow(QMainWindow):
config_loaded = QtCore.Signal()
video_type_ = QtCore.Signal(str)
video_files_ = QtCore.Signal(set)
+ engine_change = QtCore.Signal(Engine)
+ shuffle_change = QtCore.Signal(int)
+ shuffle_created = QtCore.Signal(int)
def __init__(self, app):
super(MainWindow, self).__init__()
@@ -84,6 +106,8 @@ def __init__(self, app):
self.videotype = "mp4"
self.files = set()
+ self._engine = Engine.PYTORCH
+
self.default_set()
self._generate_welcome_page()
@@ -103,6 +127,10 @@ def __init__(self, app):
self.receiver = StreamReceiver(self.writer.queue)
self.receiver.new_text.connect(self.print_to_status_bar)
+ # create logger to also log to the console
+ logging.basicConfig()
+ logging.getLogger("console").setLevel(logging.INFO)
+
self._progress_bar = QtWidgets.QProgressBar()
self._progress_bar.setMaximum(0)
self._progress_bar.hide()
@@ -111,6 +139,7 @@ def __init__(self, app):
def print_to_status_bar(self, text):
self.status_bar.showMessage(text)
self.status_bar.repaint()
+ logging.getLogger("console").info(text)
@property
def toolbar(self):
@@ -150,6 +179,43 @@ def cfg(self):
cfg = {}
return cfg
+ @property
+ def engine(self) -> Engine:
+ return self._engine
+
+ @engine.setter
+ def engine(self, e: Engine) -> None:
+ if self._engine == e:
+ return
+
+ if e == e.TF:
+ try:
+ import tensorflow
+ except ModuleNotFoundError as err:
+ msg = QtWidgets.QMessageBox()
+ msg.setIcon(QtWidgets.QMessageBox.Warning)
+ msg.setText("Cannot use the TensorFlow engine.")
+ msg.setInformativeText(
+ f"Error `{err}`\nCannot use the TensorFlow engine as TensorFlow "
+ "is not installed. To use it, install TensorFlow with\n"
+ " Windows/Linux:\n"
+ " pip install 'deeplabcut[tf]'\n"
+ " Apple Silicon:\n"
+ " pip install 'deeplabcut[apple_mchips]'\n\n"
+ "Please switch back to the PyTorch engine to use DeepLabCut, or install TensorFlow."
+ )
+
+ msg.setWindowTitle("Info")
+ msg.setMinimumWidth(900)
+ logo_dir = os.path.dirname(os.path.realpath("logo.png")) + os.path.sep
+ logo = logo_dir + "/assets/logo.png"
+ msg.setWindowIcon(QIcon(logo))
+ msg.setStandardButtons(QtWidgets.QMessageBox.Ok)
+ msg.exec_()
+
+ self._engine = e
+ self.engine_change.emit(e)
+
@property
def project_folder(self) -> str:
return self.cfg.get("project_path", os.path.expanduser("~/Desktop"))
@@ -175,19 +241,31 @@ def all_individuals(self) -> List:
@property
def pose_cfg_path(self) -> str:
try:
- return os.path.join(
- self.cfg["project_path"],
- auxiliaryfunctions.get_model_folder(
- self.cfg["TrainingFraction"][int(self.trainingset_index)],
- int(self.shuffle_value),
- self.cfg,
- ),
- "train",
- "pose_cfg.yaml",
+ return str(
+ compat.return_train_network_path(
+ self.config,
+ shuffle=int(self.shuffle_value),
+ trainingsetindex=int(self.trainingset_index),
+ modelprefix="",
+ )[0]
)
except FileNotFoundError:
return str(Path(deeplabcut.__file__).parent / "pose_cfg.yaml")
+ @property
+ def models_folder(self) -> str:
+ try:
+ return str(
+ compat.return_train_network_path(
+ self.config,
+ shuffle=int(self.shuffle_value),
+ trainingsetindex=int(self.trainingset_index),
+ modelprefix="",
+ )[2]
+ )
+ except FileNotFoundError:
+ return self.project_folder()
+
@property
def inference_cfg_path(self) -> str:
return os.path.join(
@@ -207,6 +285,7 @@ def update_cfg(self, text):
def update_shuffle(self, value):
self.shuffle_value = value
+ self.shuffle_change.emit(value)
self.logger.info(f"Shuffle set to {self.shuffle_value}")
@property
@@ -401,10 +480,51 @@ def update_menu_bar(self):
self.file_menu.removeAction(self.openAction)
def create_toolbar(self):
+ self.toolbar.clear()
self.toolbar.addAction(self.newAction)
self.toolbar.addAction(self.openAction)
self.toolbar.addAction(self.helpAction)
+ size_policy = QSizePolicy() # QtWidgets.QSizePolicy.Policy.Expanding
+ size_policy.setHorizontalPolicy(QSizePolicy.Policy.Expanding)
+ spacer = QLabel()
+ spacer.setSizePolicy(size_policy)
+ spacer.setStyleSheet("background: transparent;")
+
+ engine_label = QLabel()
+ engine_label.autoFillBackground()
+ engine_label.setText("Engine")
+ engine_label.setStyleSheet("background: transparent;")
+
+ engine_icon = QLabel()
+ engine_icon.setStyleSheet("background: transparent;")
+
+ def _update_icon(engine: str):
+ pixmap = QPixmap(f"deeplabcut/gui/media/dlc-{engine}.png")
+ engine_icon.setPixmap(
+ pixmap.scaled(56, 56, Qt.AspectRatioMode.KeepAspectRatio)
+ )
+
+ _update_icon("pt" if self.engine == Engine.PYTORCH else "tf")
+
+ engines = [engine for engine in Engine]
+
+ def _update_engine(index: int) -> None:
+ self.logger.info(f"Changed engine to {engines[index]}")
+ self.engine = engines[index]
+ _update_icon("pt" if self.engine == Engine.PYTORCH else "tf")
+
+ change_engine_widget = QComboBox()
+ change_engine_widget.addItems([e.aliases[0] for e in engines])
+ change_engine_widget.setFixedWidth(180)
+ change_engine_widget.currentIndexChanged.connect(_update_engine)
+ change_engine_widget.setCurrentIndex(engines.index(self.engine))
+
+ self.toolbar.addWidget(spacer)
+ self.toolbar.addWidget(engine_icon)
+ self.toolbar.addWidget(engine_label)
+ self.toolbar.addWidget(change_engine_widget)
+
def remove_action(self):
self.toolbar.removeAction(self.newAction)
self.toolbar.removeAction(self.openAction)
@@ -503,7 +623,9 @@ def add_tabs(self):
h1_description="DeepLabCut - Step 4. Create training dataset",
)
self.train_network = TrainNetwork(
- root=self, parent=None, h1_description="DeepLabCut - Train network"
+ root=self,
+ parent=None,
+ h1_description="DeepLabCut - Train network",
)
self.evaluate_network = EvaluateNetwork(
root=self,
@@ -557,8 +679,10 @@ def add_tabs(self):
self.tab_widget.addTab(self.video_editor, "Video editor (*)")
if not self.is_multianimal:
- self.refine_tracklets.setEnabled(False)
- self.unsupervised_id_tracking.setEnabled(self.is_transreid_available())
+ self.tab_widget.removeTab(
+ self.tab_widget.indexOf(self.unsupervised_id_tracking)
+ )
+ self.tab_widget.removeTab(self.tab_widget.indexOf(self.refine_tracklets))
self.setCentralWidget(self.tab_widget)
self.tab_widget.currentChanged.connect(self.refresh_active_tab)
diff --git a/deeplabcut/modelzoo/__init__.py b/deeplabcut/modelzoo/__init__.py
index 2dd1b06028..96c227b4b4 100644
--- a/deeplabcut/modelzoo/__init__.py
+++ b/deeplabcut/modelzoo/__init__.py
@@ -4,7 +4,8 @@
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
+from deeplabcut.modelzoo.weight_initialization import build_weight_init
diff --git a/deeplabcut/modelzoo/conversion_tables/conversion_table_quadruped.csv b/deeplabcut/modelzoo/conversion_tables/conversion_table_quadruped.csv
new file mode 100644
index 0000000000..5060943b21
--- /dev/null
+++ b/deeplabcut/modelzoo/conversion_tables/conversion_table_quadruped.csv
@@ -0,0 +1,40 @@
+ap10k,animalpose,stanforddogs,cheetah,horse,webapp,MasterName
+nose,nose,Nose,nose,Nose,nose,nose
+,,,,,,upper_jaw
+,,,,,,lower_jaw
+,,,,,,mouth_end_right
+,,,,,,mouth_end_left
+right_eye,right_eye,R_Eye,r_eye,Eye,right_eye,right_eye
+,right_ear,R_EarBase,,,right_ear,right_earbase
+,,R_EarTip,,,,right_earend
+,,,,,,right_antler_base
+,,,,,,right_antler_end
+left_eye,left_eye,L_Eye,l_eye,,left_eye,left_eye
+,left_ear,L_EarBase,,,left_ear,left_earbase
+,,L_EarTip,,,,left_earend
+,,,,,,left_antler_base
+,,,,,,left_antler_end
+neck,,,neck_base,,,neck_base
+,,,,,,neck_end
+,throat,Throat,,,throat,throat_base
+,,,,,,throat_end
+,withers,Withers,,Wither,withers,back_base
+,,,,,,back_end
+,,,spine,,,back_middle
+root_of_tail,tailbase,TailBase,tail_base,,tailset,tail_base
+,,TailEnd,tail_tip,,,tail_end
+left_shoulder,left_front_elbow,L_F_Elbow,l_shoulder,,left_front_elbow,front_left_thai
+,left_front_knee,L_F_Knee,l_front_knee,,,front_left_knee
+left_front_paw,left_front_paw,L_F_Paw,l_front_paw,Nearfrontfoot,left_front_paw,front_left_paw
+right_shoulder,right_front_elbow,R_F_Elbow,r_shoulder,Elbow,right_front_elbow,front_right_thai
+,right_front_knee,R_F_Knee,r_front_knee,,,front_right_knee
+right_front_paw,right_front_paw,R_F_Paw,r_front_paw,Offfrontfoot,right_front_paw,front_right_paw
+left_back_paw,left_back_paw,L_B_Paw,l_back_paw,Nearhindfoot,left_back_paw,back_left_paw
+left_hip,left_back_elbow,L_B_Elbow,l_hip,,left_back_stifle,back_left_thai
+right_hip,right_back_elbow,R_B_Elbow,r_hip,Stifle,right_back_stifle,back_right_thai
+left_knee,left_back_knee,L_B_Knee,l_back_knee,,,back_left_knee
+right_knee,right_back_knee,R_B_Knee,r_back_knee,,,back_right_knee
+right_back_paw,right_back_paw,R_B_Paw,r_back_paw,Offhindfoot,right_back_paw,back_right_paw
+,,,,,,belly_bottom
+,,,,,,body_middle_right
+,,,,,,body_middle_left
\ No newline at end of file
diff --git a/deeplabcut/modelzoo/conversion_tables/conversion_table_topview.csv b/deeplabcut/modelzoo/conversion_tables/conversion_table_topview.csv
new file mode 100644
index 0000000000..b02098abee
--- /dev/null
+++ b/deeplabcut/modelzoo/conversion_tables/conversion_table_topview.csv
@@ -0,0 +1,28 @@
+treadmill_ole,swimming_ole,openfield_ole,MackenzieMausHaus, ChanLab,Daniel3Mouse,dlc-openfield,EPM ,FST,LBD,OFT,Mostafizur,3CSI,BM,TwoWhiteMice,MasterName
+head,head,head,nose,Nose,snout,snout,nose,nose,nose,nose,snout,nose,nose,Nose,nose
+,,,leftearbase,Ear_left,leftear,leftear,earl,earl,earl,earl,leftear,earl,earl,Left_ear,left_ear
+,,,rightearbase,Ear_right,rightear,rightear,earr,earr,earr,earr,rightear,earr,earr,Right_ear,right_ear
+,,,lefteartip,,,,,,,,,,,,left_ear_tip
+,,,righteartip,,,,,,,,,,,,right_ear_tip
+,,,lefteye,,,,,,,,,,,,left_eye
+,,,righteye,,,,,,,,,,,,right_eye
+spine 1,spine 1,,spine1,,shoulder,,neck,neck,neck,neck,shoulder,neck,neck,,neck
+,,,spine2,,spine1,,,,,,spine1,,,,mid_back
+spine 2,spine 2,middle,spine3,Center,spine2,,bodycentre,bodycentre,bodycentre,bodycentre,spine2,bodycenter,bodycenter,Centroid,mouse_center
+,,,spine4,,spine3,,,,,,spine3,,,,mid_backend
+spine 3,spine 3,,spine5,,spine4,,,,,,spine4,,,,mid_backend2
+spine 4,spine 4,,spine6,,,,,,,,,,,,mid_backend3
+base ,base ,tailbase,tailbase,Tail_base,tailbase,tailbase,tailbase,tailbase,tailbase,tailbase,tailbase,tailbase,tailbase,Tail_base,tail_base
+,,,tail1,,tail1,,,,,,tail1,,,,tail1
+tail 25,tail 25,,tail2,,tail2,,,,,,tail2,,,,tail2
+,,,tail3,,,,tailcentre,tailcentre,tailcentre,tailcentre,,tailcenter,tailcenter,,tail3
+tail 50 ,tail 50,,tail4,, ,,,,,,,,,,tail4
+tail 75,tail 75,,tail5,, ,,,,,,,,,,tail5
+,,,leftshoulder,, ,,,,,,,,,,left_shoulder
+,,,leftside,, ,,bcl,bcl,bcl,bcl,,bcl,bcl,Left_lateral,left_midside
+,,,lefthip,Lateral_left,,,hipl,hipl,hipl,hipl,,hipl,hipl,,left_hip
+,,,rightshoulder,,,,,,,,,,,,right_shoulder
+,,,rightside,,,,bcr,bcr,bcr,bcr,,bcr,bcr,Right_lateral,right_midside
+,,,righthip,Lateral_right,,,hipr,hipr,hipr,hipr,,hipr,hipr,,right_hip
+tail 100,tail 100,tailtip,,Tail_end,tailend,,tailtip,tailtip,tailtip,tailtip,tailend,tailtip,tailtip,Tail_end,tail_end
+,,,,,,,headcentre,headcentre,headcentre,headcentre,,headcenter,headcenter,,head_midpoint
diff --git a/deeplabcut/modelzoo/generalized_data_converter/__init__.py b/deeplabcut/modelzoo/generalized_data_converter/__init__.py
new file mode 100644
index 0000000000..fb1e45d7ba
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/__init__.py
@@ -0,0 +1,11 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from .utils import add_skeleton, customized_colormap, create_modelprefix
diff --git a/deeplabcut/modelzoo/generalized_data_converter/conversion_table/__init__.py b/deeplabcut/modelzoo/generalized_data_converter/conversion_table/__init__.py
new file mode 100644
index 0000000000..7a3cd50142
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/conversion_table/__init__.py
@@ -0,0 +1,11 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from .conversion_table import get_conversion_table
diff --git a/deeplabcut/modelzoo/generalized_data_converter/conversion_table/conversion_table.py b/deeplabcut/modelzoo/generalized_data_converter/conversion_table/conversion_table.py
new file mode 100644
index 0000000000..a05d58d3a3
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/conversion_table/conversion_table.py
@@ -0,0 +1,153 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import warnings
+
+import numpy as np
+import pandas as pd
+
+
+class ConversionTableFromDict:
+ def __init__(self, raw_table_dict):
+ self.table_dict = raw_table_dict["conversion_table"]
+ self.master_keypoints = raw_table_dict["master_keypoints"]
+
+ def convert(self, kpt):
+ if kpt not in self.table_dict:
+ warnings.warn(
+ f"{kpt} is defined in src space but not appeared in the conversion table"
+ )
+ return None
+ else:
+ return self.table_dict[kpt]
+
+
+class ConversionTableFromCSV:
+ """
+ Base class only reads the table
+ """
+
+ def __init__(self, src_keypoints, table_path):
+ self.table_path = table_path
+
+ # sep removes leading and tailing white space
+ df = pd.read_csv(table_path, sep="\s*,\s*")
+
+ df.dropna(inplace=True, how="all")
+ # drop the row is MasterName has nan in the row
+ df = df.dropna(subset=["MasterName"])
+
+ self.df = df
+
+ self.src_keypoints = src_keypoints
+
+ kpt_list = df.to_numpy()
+
+ self.lookup_set = []
+
+ for i in range(len(kpt_list)):
+ kpts = np.array(kpt_list[i])
+ # remove nan
+
+ kpt_alias = set(kpts)
+
+ for k in list(kpt_alias):
+ if type(k) != str:
+ kpt_alias.remove(k)
+
+ self.lookup_set.append(kpt_alias)
+
+ target_keypoints = df["MasterName"].values
+
+ # target_keypoints = target_keypoints[~np.isnan(target_keypoints.values)]
+
+ self.master_keypoints = target_keypoints
+
+ # paired when they both exist
+
+ # following assumes that either it's 1vs.1 from src to target
+ # or 1 vs. 0
+ # it could be 1 vs. 2 in horse data
+ self.table = {}
+ for src_kpt in src_keypoints:
+ for target_kpt in target_keypoints:
+
+ src_kpt_id = self._search(src_kpt)
+ target_kpt_id = self._search(target_kpt)
+
+ if src_kpt_id == -1 or target_kpt_id == -1:
+ # if any one of them not exist in the set
+ # skip
+ continue
+ if src_kpt_id == target_kpt_id:
+ self.table[src_kpt] = target_kpt
+
+ self.check_inclusion()
+
+ def _search(self, key):
+ """
+ return -1 if not found
+ return kpt id if found
+
+ """
+ # [TODO] if it can be mapped to two, I can randomly return one
+ for kpt_id in range(len(self.lookup_set)):
+ if key in self.lookup_set[kpt_id]:
+ return kpt_id
+ return -1
+
+ def check_inclusion(self):
+ """
+ check if conversion table covers
+ every keypoint contained in src proj
+
+ """
+ count = 0
+ print("src keypoints")
+ print(self.src_keypoints)
+ for kpt in self.src_keypoints:
+ index = self._search(kpt)
+ if index == -1:
+ pass
+ else:
+ count += 1
+ print(f"{count}/{len(self.src_keypoints)} keypoints will be converted")
+
+ def convert(self, kpt):
+ if kpt not in self.table:
+ warnings.warn(
+ f"{kpt} is defined in src space but not appeared in the conversion table"
+ )
+ return None
+ else:
+ return self.table[kpt]
+
+ def get_subset(self, labname=""):
+
+ bodyparts = self.df[labname]
+
+ super_bodyparts = self.df["MasterName"]
+
+ ret = []
+
+ for bodypart in bodyparts:
+ if bodypart in self.table:
+ ret.append(self.table[bodypart])
+
+ return ret
+
+
+def get_conversion_table(keypoints=None, table_path="", table_dict=None):
+ if table_path is not None and keypoints is not None:
+ return ConversionTableFromCSV(keypoints, table_path)
+ elif table_dict:
+ return ConversionTableFromDict(table_dict)
+ else:
+ raise NotImplementedError("not supported")
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/__init__.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/__init__.py
new file mode 100644
index 0000000000..47e42a1bd7
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/__init__.py
@@ -0,0 +1,17 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from .ma_dlc import MaDLCPoseDataset
+from .multi import MultiSourceDataset
+from .coco import COCOPoseDataset
+from .materialize import mat_func_factory
+from .single_dlc import SingleDLCPoseDataset
+from .single_dlc_dataframe import SingleDLCDataFrame
+from .ma_dlc_dataframe import MaDLCDataFrame
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/base.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/base.py
new file mode 100644
index 0000000000..8ea478ce00
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/base.py
@@ -0,0 +1,325 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import copy
+import os
+import warnings
+
+import numpy as np
+
+from deeplabcut.modelzoo.generalized_data_converter.conversion_table import (
+ get_conversion_table,
+)
+from deeplabcut.modelzoo.generalized_data_converter.datasets.materialize import (
+ mat_func_factory,
+)
+
+
+def raw_2_imagename_with_id(image):
+ """
+ raw image data has filename and id.
+ we modify the imagename such that itis composed of
+ both original imagename and image id
+ """
+
+ file_name = image["file_name"]
+ image_name = file_name.split(os.sep)[-1]
+ pre, suffix = image_name.split(".")
+ image_id = image["id"]
+ return f"{pre}_{image_id}.{suffix}"
+
+
+def raw_2_imagename(image):
+ """
+ Only getting the imagename part from the image object
+ """
+
+ file_name = image["file_name"]
+ image_name = file_name.split(os.sep)[-1]
+ return image_name
+
+
+class BasePoseDataset:
+ """
+ Dual representation of generic and raw data. For classes that inherits this class,
+ the raw data is kept but generic data is populated so you have dual representation.
+ """
+
+ def __init__(self):
+ # generic data is what all the manipulation is based on
+ self.generic_train_images = []
+ self.generic_test_images = []
+ self.generic_train_annotations = []
+ self.generic_test_annotations = []
+ # These maps are very important for later analysis, including max_individuals
+ # and trace back the original dataset etc.
+ self.imageid2anno = {}
+ self.dataset2images = {}
+ self.imageid2filename = {}
+ self.imageid2datasetname = {}
+ self.datasetname2imageids = {}
+ # meta keeps information for later analysis
+ self.meta = {}
+ # if conversion_table is None, dataset is not yet converted to super keypoints
+ self.conversion_table = None
+
+ def _build_maps(self):
+ self.datasetname2imageids[self.meta["dataset_name"]] = set()
+
+ total_annotations = (
+ self.generic_train_annotations + self.generic_test_annotations
+ )
+ for anno in total_annotations:
+ image_id = anno["image_id"]
+ if image_id not in self.imageid2anno:
+ self.imageid2anno[image_id] = []
+ self.imageid2anno[image_id].append(anno)
+
+ total_images = self.generic_train_images + self.generic_test_images
+ for image in total_images:
+ image_id = image["id"]
+ self.imageid2datasetname[image_id] = self.meta["dataset_name"]
+ file_name = image["file_name"]
+ self.imageid2filename[image_id] = file_name
+ self.datasetname2imageids[self.meta["dataset_name"]].add(image_id)
+
+ # in DLC, even if you have more than one annotations in one image, it does not
+ # mean it's a multi animal project
+ max_num = 0
+ for k in self.imageid2anno:
+ max_num = max(len(self.imageid2anno[k]), max_num)
+
+ self.meta["max_individuals"] = max_num
+ self.meta["imageid2filename"] = self.imageid2filename
+
+ def filter_by_pattern(self, pattern):
+
+ keep_ids = []
+ keep_train_images = []
+ keep_test_images = []
+ for img in self.generic_train_images + self.generic_test_images:
+ print(img["file_name"])
+ if pattern in img["file_name"]:
+
+ image_id = img["id"]
+ keep_ids.append(image_id)
+
+ for image in self.generic_train_images:
+ if image["id"] in keep_ids:
+ keep_train_images.append(image["id"])
+
+ self.generic_train_images = keep_train_images
+
+ for image in self.generic_test_images:
+ if image["id"] in keep_ids:
+ keep_test_images.append(image["id"])
+
+ self.generic_test_images = keep_test_images
+
+ keep_train_annotations = []
+ keep_test_annotations = []
+
+ for anno in self.generic_train_annotations:
+ if anno["image_id"] in keep_ids:
+ keep_train_annotations.append(anno)
+
+ self.generic_train_annotations = keep_train_annotations
+
+ for anno in self.generic_test_annotations:
+ if anno["image_id"] in keep_ids:
+ keep_test_annotations.append(anno)
+
+ self.generic_test_annotations = keep_test_annotations
+
+ def summary(self):
+ print(f'Summary of dataset {self.meta["dataset_name"]}')
+ print("-------------")
+ print(f'max num individuals is {self.meta["max_individuals"]}')
+ print(f"total keypoints : {len(self.meta['categories']['keypoints'])}")
+ print(f"total train images : {len(self.generic_train_images)}")
+ print(f"total test images : {len(self.generic_test_images)}")
+ print(f"total train annotations : {len(self.generic_train_annotations)}")
+ print(f"total test annotations : {len(self.generic_test_annotations)}")
+ print("-------------")
+
+ def populate_generic(self):
+ raise NotImplementedError("Must implement this function")
+
+ def materialize(
+ self,
+ proj_root,
+ framework="coco",
+ deepcopy=False,
+ append_image_id=True,
+ no_image_copy=False,
+ ):
+ mat_func = mat_func_factory(framework)
+ self.meta["mat_datasets"] = {self.meta["dataset_name"]: self}
+ self.meta["imageid2datasetname"] = self.imageid2datasetname
+ kwargs = dict(deepcopy=deepcopy, append_image_id=append_image_id)
+ if framework == "coco":
+ kwargs["no_image_copy"] = no_image_copy
+
+ mat_func(
+ proj_root,
+ self.generic_train_images,
+ self.generic_test_images,
+ self.generic_train_annotations,
+ self.generic_test_annotations,
+ self.meta,
+ **kwargs,
+ )
+
+ def whether_anno_image_match(self, images, annotations):
+ """
+ Every image id should be annotated at least once
+ There should not be any image that is not being annotated
+ There should not be any annotation for beyond the set of given images
+ """
+
+ image_ids = set([image["id"] for image in images])
+
+ annotation_image_ids = set([anno["image_id"] for anno in annotations])
+
+ if image_ids != annotation_image_ids:
+ print("images-annotations", image_ids - annotation_image_ids)
+ print("len(images-annotatinos)", len(image_ids - annotation_image_ids))
+ print("annotations-images", annotation_image_ids - image_ids)
+ print("len(annotations-images)", len(annotation_image_ids - image_ids))
+ warnings.warn("annotation and image ids do not match")
+
+ def get_keypoints(self):
+ # TODO make sure it's always one element in a list
+ return self.meta["categories"]["keypoints"]
+
+ def _proj(self, annotations, conversion_table):
+
+ keypoints = self.get_keypoints()
+
+ kpt2index = {kpt: kpt_id for kpt_id, kpt in enumerate(keypoints)}
+
+ ret = []
+
+ master2src = {}
+ for kpt in keypoints:
+ conv_kpt = conversion_table.convert(kpt)
+ # sometimes a keypoint might not find its corresponding one from mastername
+ if conv_kpt is not None:
+ master2src[conv_kpt] = kpt
+
+ master_keypoints = conversion_table.master_keypoints
+
+ # need to change this in meta
+
+ for anno in annotations:
+ try:
+ kpts = anno["keypoints"]
+ except:
+ print(anno)
+
+ new_kpts = np.zeros(len(master_keypoints) * 3)
+ new_num_kpts = len(master_keypoints)
+
+ for master_kpt_id, master_kpt_name in enumerate(master_keypoints):
+ # check whether the dataset has the corresponding keypoint
+ if master_kpt_name not in master2src:
+ new_kpts[master_kpt_id * 3 : master_kpt_id * 3 + 3] = -1
+ continue
+
+ src_kpt_name = master2src[master_kpt_name]
+ src_kpt_id = kpt2index[src_kpt_name]
+ new_kpts[master_kpt_id * 3 : master_kpt_id * 3 + 3] = kpts[
+ src_kpt_id * 3 : src_kpt_id * 3 + 3
+ ]
+
+ # skipping empty frames after conversion
+ new_anno = copy.deepcopy(anno)
+ new_anno["keypoints"] = new_kpts
+ new_anno["num_keypoints"] = new_num_kpts
+ ret.append(new_anno)
+
+ return ret
+
+ def adjust_bbox_and_area(self):
+ """Called during conversion.
+
+ This is to remove the impact of keypoints that are potentially environmental
+ keypoints to the bbox and area calculation.
+ """
+ from .utils import calc_bboxes_from_keypoints
+
+ for annotation in (
+ self.generic_train_annotations + self.generic_test_annotations
+ ):
+ keypoints = annotation["keypoints"]
+ bbox_margin = 20
+
+ num_kpts = annotation["num_keypoints"]
+
+ keypoints = np.array(keypoints).reshape((num_kpts, 3))
+
+ mask = keypoints[:, 0] > 0
+ keypoints = keypoints[mask]
+
+ if keypoints.shape[0] == 0:
+ continue
+
+ xmin, ymin, xmax, ymax = calc_bboxes_from_keypoints(
+ [keypoints],
+ slack=bbox_margin,
+ clip=True,
+ )[0][:4]
+
+ w = xmax - xmin
+ h = ymax - ymin
+ area = w * h
+ bbox = np.nan_to_num([xmin, ymin, w, h])
+
+ if "bbox" not in annotation:
+ annotation["bbox"] = bbox
+ if "area" not in annotation:
+ annotation["area"] = area
+
+ def project_with_conversion_table(self, table_path="", table_dict=None):
+ """
+ Replace the generic annotations with those that are in superset keypoint space
+
+ """
+ print(f'Converting {self.meta["dataset_name"]}')
+
+ keypoints = self.get_keypoints()
+
+ self.conversion_table = get_conversion_table(
+ keypoints=keypoints, table_path=table_path, table_dict=table_dict
+ )
+
+ self.generic_train_annotations = self._proj(
+ self.generic_train_annotations, self.conversion_table
+ )
+
+ self.generic_test_annotations = self._proj(
+ self.generic_test_annotations, self.conversion_table
+ )
+
+ # all category id fixed to 1. So that it does not conflict with the background
+ # category id
+ for anno in self.generic_train_annotations + self.generic_test_annotations:
+ anno["category_id"] = 1
+
+ for img in self.generic_train_images + self.generic_test_images:
+ img["source_dataset"] = self.meta["dataset_name"]
+
+ self.adjust_bbox_and_area()
+ self.meta["categories"]["keypoints"] = self.conversion_table.master_keypoints
+ self.meta["categories"]["supercategory"] = "animal"
+ self.meta["categories"]["name"] = "superanimal"
+
+ # category id fixed to be 1, to avoid to conflict with background category id
+ self.meta["categories"]["id"] = 1
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/base_dlc.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/base_dlc.py
new file mode 100644
index 0000000000..d816449e4a
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/base_dlc.py
@@ -0,0 +1,115 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+import pickle
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base import BasePoseDataset
+from deeplabcut.utils import auxiliaryfunctions
+
+
+class BaseDLCPoseDataset(BasePoseDataset):
+
+ def __init__(self, proj_root, dataset_name, shuffle=1, modelprefix=""):
+ super(BaseDLCPoseDataset, self).__init__()
+
+ assert proj_root != None and dataset_name != None
+
+ self.meta["dataset_name"] = dataset_name
+ self.meta["proj_root"] = proj_root
+ self.meta["shuffle"] = shuffle
+ self.meta["modelprefix"] = modelprefix
+
+ self.proj_root = proj_root
+
+ if modelprefix:
+ config_file = os.path.join(self.proj_root, modelprefix + "_config.yaml")
+ else:
+ config_file = os.path.join(self.proj_root, "config.yaml")
+
+ cfg = auxiliaryfunctions.read_config(config_file)
+
+ task = cfg["Task"]
+
+ scorer = cfg["scorer"]
+
+ datasets_folder = os.path.join(
+ self.proj_root,
+ auxiliaryfunctions.GetTrainingSetFolder(cfg),
+ )
+
+ self.datasets_folder = datasets_folder
+
+ trainingFraction = int(cfg["TrainingFraction"][0] * 100)
+
+ path_dlc_collected = os.path.join(datasets_folder, f"CollectedData_{scorer}.h5")
+
+ path_dlc_document = os.path.join(
+ datasets_folder,
+ f"Documentation_data-{task}_{trainingFraction}shuffle{shuffle}.pickle",
+ )
+
+ df = pd.read_hdf(path_dlc_collected)
+
+ self.dlc_df = df
+
+ with open(path_dlc_document, "rb") as f:
+ document_data = pickle.load(f)
+
+ train_indices = document_data[1]
+ # index 2 is test indices
+ test_indices = document_data[2]
+
+ train_images = df.index[train_indices]
+ test_images = df.index[test_indices]
+
+ self.dlc_images = np.hstack([train_images, test_images])
+
+ df_train = df.loc[train_images]
+
+ df_test = df.loc[test_images]
+
+ self.coco_train = self._df2generic(df_train)
+
+ offset = len(self.coco_train["images"])
+
+ self.coco_test = self._df2generic(df_test, image_id_offset=offset)
+
+ self.populate_generic()
+
+ def _df2generic(self, df, image_id_offset=0):
+ raise NotImplementedError()
+
+ def populate_generic(self):
+
+ self.generic_train_images = self.coco_train["images"]
+ self.generic_test_images = self.coco_test["images"]
+ self.generic_train_annotations = self.coco_train["annotations"]
+ self.generic_test_annotations = self.coco_test["annotations"]
+
+ self.meta["categories"] = self.coco_test["categories"][0]
+
+ # to build maps for later analysis
+ self._build_maps()
+
+ print(f"Before checking trainset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_train_images, self.generic_train_annotations
+ )
+
+ print(f"Before checking testset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_test_images, self.generic_test_annotations
+ )
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/coco.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/coco.py
new file mode 100644
index 0000000000..7350a5fec0
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/coco.py
@@ -0,0 +1,90 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import copy
+import json
+import os
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base import BasePoseDataset
+
+
+class COCOPoseDataset(BasePoseDataset):
+ def __init__(
+ self,
+ proj_root,
+ dataset_name,
+ train_filename="train.json",
+ shuffle=None,
+ ):
+
+ super(COCOPoseDataset, self).__init__()
+
+ self.meta["dataset_name"] = dataset_name
+ self.meta["proj_root"] = proj_root
+
+ self.proj_root = proj_root
+ self.annotations_by_category = {}
+
+ self.train_json_obj = (
+ self._load_json(train_filename)
+ if shuffle is None
+ else self._load_json(
+ train_filename.replace(".json", f"_shuffle{shuffle}.json")
+ )
+ )
+ self.test_json_obj = (
+ self._load_json("test.json")
+ if shuffle is None
+ else self._load_json(f"test_shuffle{shuffle}.json")
+ )
+
+ self.populate_generic()
+
+ def _load_json(self, json_fn):
+ path = os.path.join(self.proj_root, "annotations", json_fn)
+ with open(path, "r") as f:
+ json_obj = json.load(f)
+ return json_obj
+
+ def populate_generic(self):
+
+ temp_train_images = copy.deepcopy(self.train_json_obj["images"])
+ temp_test_images = copy.deepcopy(self.test_json_obj["images"])
+
+ for image in temp_train_images + temp_test_images:
+ image_path = image["file_name"]
+ # if os.sep not in image_path:
+ # assuming the file_name is mmpose style, i.e. only the image name is stored
+ # so we need to add back absolute path
+
+ image["file_name"] = os.path.join(self.proj_root, "images", image_path)
+
+ self.generic_train_images = temp_train_images
+ self.generic_test_images = temp_test_images
+
+ self.generic_train_annotations = self.train_json_obj["annotations"]
+
+ self.generic_test_annotations = self.test_json_obj["annotations"]
+
+ self.meta["categories"] = self.test_json_obj["categories"][0]
+
+ self._build_maps()
+
+ print(f"Before checking trainset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_train_images, self.generic_train_annotations
+ )
+
+ print(f"Before checking testset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_test_images, self.generic_test_annotations
+ )
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc.py
new file mode 100644
index 0000000000..4a13c14f37
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc.py
@@ -0,0 +1,160 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base_dlc import (
+ BaseDLCPoseDataset,
+)
+from deeplabcut.modelzoo.generalized_data_converter.datasets.utils import (
+ calc_bboxes_from_keypoints,
+ read_image_shape_fast,
+)
+
+
+class MaDLCPoseDataset(BaseDLCPoseDataset):
+ def __init__(self, proj_root, dataset_name, shuffle=1, modelprefix=""):
+ super(MaDLCPoseDataset, self).__init__(
+ proj_root, dataset_name, shuffle=shuffle, modelprefix=modelprefix
+ )
+
+ def _df2generic(self, df, image_id_offset=0):
+
+ individuals = df.columns.get_level_values("individuals").unique().tolist()
+
+ unique_bpts = []
+
+ if "single" in individuals:
+ unique_bpts.extend(
+ df.xs("single", level="individuals", axis=1)
+ .columns.get_level_values("bodyparts")
+ .unique()
+ )
+ multi_bpts = (
+ df.xs(individuals[0], level="individuals", axis=1)
+ .columns.get_level_values("bodyparts")
+ .unique()
+ .tolist()
+ )
+
+ coco_categories = []
+
+ # assuming all individuals have the same name and same category id
+
+ individual = individuals[0]
+
+ category = {
+ "name": individual,
+ "id": 0,
+ "supercategory": "animal",
+ }
+
+ if individual == "single":
+ category["keypoints"] = unique_bpts
+ else:
+ category["keypoints"] = multi_bpts
+
+ coco_categories.append(category)
+
+ coco_images = []
+ coco_annotations = []
+
+ annotation_id = 0
+ image_id = -1
+ for _, file_name in enumerate(df.index):
+ data = df.loc[file_name]
+
+ # skipping all nan
+ if np.isnan(data.to_numpy()).all():
+ continue
+
+ image_id += 1
+
+ for individual_id, individual in enumerate(individuals):
+ category_id = 0
+ try:
+ kpts = (
+ data.xs(individual, level="individuals")
+ .to_numpy()
+ .reshape((-1, 2))
+ )
+ except:
+ # somehow there are duplicates. So only use the first occurrence
+ data = data.iloc[0]
+ kpts = (
+ data.xs(individual, level="individuals")
+ .to_numpy()
+ .reshape((-1, 2))
+ )
+
+ keypoints = np.zeros((len(kpts), 3))
+
+ keypoints[:, :2] = kpts
+
+ is_visible = ~pd.isnull(kpts).all(axis=1)
+
+ keypoints[:, 2] = np.where(is_visible, 2, 0)
+
+ num_keypoints = is_visible.sum()
+
+ bbox_margin = 20
+
+ xmin, ymin, xmax, ymax = calc_bboxes_from_keypoints(
+ [keypoints],
+ slack=bbox_margin,
+ clip=True,
+ )[0][:4]
+
+ w = xmax - xmin
+ h = ymax - ymin
+ area = w * h
+ bbox = np.nan_to_num([xmin, ymin, w, h])
+ keypoints = np.nan_to_num(keypoints.flatten())
+
+ annotation_id += 1
+ annotation = {
+ "image_id": image_id + image_id_offset,
+ "num_keypoints": num_keypoints,
+ "keypoints": keypoints,
+ "id": annotation_id,
+ "category_id": category_id,
+ "area": area,
+ "bbox": bbox,
+ "iscrowd": 0,
+ }
+ if np.sum(keypoints) != 0:
+ coco_annotations.append(annotation)
+
+ # I think width and height are important
+
+ if isinstance(file_name, tuple):
+ image_path = os.path.join(self.proj_root, *list(file_name))
+ else:
+ image_path = os.path.join(self.proj_root, file_name)
+
+ _, height, width = read_image_shape_fast(image_path)
+
+ image = {
+ "file_name": image_path,
+ "width": width,
+ "height": height,
+ "id": image_id + image_id_offset,
+ }
+ coco_images.append(image)
+
+ ret_obj = {
+ "images": coco_images,
+ "annotations": coco_annotations,
+ "categories": coco_categories,
+ }
+ return ret_obj
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc_dataframe.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc_dataframe.py
new file mode 100644
index 0000000000..0278d39a54
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/ma_dlc_dataframe.py
@@ -0,0 +1,280 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
+ parse_video_filenames,
+)
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base import BasePoseDataset
+from deeplabcut.modelzoo.generalized_data_converter.datasets.utils import (
+ calc_bboxes_from_keypoints,
+ read_image_shape_fast,
+)
+from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions, conversioncode
+
+
+def merge_annotateddatasets(cfg):
+ """
+ Merges all the h5 files for all labeled-datasets (from individual videos).
+
+ This is a bit of a mess because of cross platform compatibility.
+
+ Within platform comp. is straightforward. But if someone labels on windows and wants to train on a unix cluster or colab...
+ """
+ AnnotationData = []
+ data_path = Path(os.path.join(cfg["project_path"], "labeled-data"))
+ videos = cfg["video_sets"].keys()
+ video_filenames = parse_video_filenames(videos)
+ for filename in video_filenames:
+ file_path = os.path.join(
+ data_path / filename, f'CollectedData_{cfg["scorer"]}.h5'
+ )
+ try:
+ data = pd.read_hdf(file_path)
+ conversioncode.guarantee_multiindex_rows(data)
+ if data.columns.levels[0][0] != cfg["scorer"]:
+ print(
+ f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)"
+ )
+ continue
+ AnnotationData.append(data)
+ except FileNotFoundError:
+ print(file_path, " not found (perhaps not annotated).")
+
+ if not len(AnnotationData):
+ print(
+ "Annotation data was not found by splitting video paths (from config['video_sets']). An alternative route is taken..."
+ )
+ AnnotationData = conversioncode.merge_windowsannotationdataONlinuxsystem(cfg)
+ if not len(AnnotationData):
+ print("No data was found!")
+ return
+
+ AnnotationData = pd.concat(AnnotationData).sort_index()
+ # When concatenating DataFrames with misaligned column labels,
+ # all sorts of reordering may happen (mainly depending on 'sort' and 'join')
+ # Ensure the 'bodyparts' level agrees with the order in the config file.
+ if cfg.get("multianimalproject", False):
+ (
+ _,
+ uniquebodyparts,
+ multianimalbodyparts,
+ ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
+ bodyparts = multianimalbodyparts + uniquebodyparts
+ else:
+ bodyparts = cfg["bodyparts"]
+ AnnotationData = AnnotationData.reindex(
+ bodyparts, axis=1, level=AnnotationData.columns.names.index("bodyparts")
+ )
+
+ return AnnotationData
+
+
+class MaDLCDataFrame(BasePoseDataset):
+
+ def __init__(self, proj_root, dataset_name):
+ super(MaDLCDataFrame, self).__init__()
+ assert proj_root != None and dataset_name != None
+ self.proj_root = proj_root
+ self.dataset_name = dataset_name
+ self.meta["dataset_name"] = dataset_name
+ self.meta["proj_root"] = proj_root
+ config_path = Path(proj_root) / "config.yaml"
+ # read config
+ cfg = auxiliaryfunctions.read_config(config_path)
+ # get the train folder
+
+ Data = merge_annotateddatasets(
+ cfg,
+ )
+
+ # now with this data, we construct necessary generic data
+
+ self.dlc_df = Data
+
+ images = self.dlc_df.index
+
+ ratio = 0.9
+
+ df_train = self.dlc_df.iloc[: int(len(images) * ratio)]
+ df_test = self.dlc_df.iloc[int(len(images) * ratio) :]
+
+ self.coco_train = self._df2generic(df_train)
+
+ offset = len(self.coco_train["images"])
+
+ self.coco_test = self._df2generic(df_test, image_id_offset=offset)
+
+ self.populate_generic()
+
+ def populate_generic(self):
+
+ self.generic_train_images = self.coco_train["images"]
+ self.generic_test_images = self.coco_test["images"]
+ self.generic_train_annotations = self.coco_train["annotations"]
+ self.generic_test_annotations = self.coco_test["annotations"]
+
+ self.meta["categories"] = self.coco_test["categories"][0]
+
+ # to build maps for later analysis
+ self._build_maps()
+
+ print(f"Before checking trainset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_train_images, self.generic_train_annotations
+ )
+
+ print(f"Before checking testset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_test_images, self.generic_test_annotations
+ )
+
+ def _df2generic(self, df, image_id_offset=0):
+
+ individuals = df.columns.get_level_values("individuals").unique().tolist()
+
+ unique_bpts = []
+
+ if "single" in individuals:
+ unique_bpts.extend(
+ df.xs("single", level="individuals", axis=1)
+ .columns.get_level_values("bodyparts")
+ .unique()
+ )
+ multi_bpts = (
+ df.xs(individuals[0], level="individuals", axis=1)
+ .columns.get_level_values("bodyparts")
+ .unique()
+ .tolist()
+ )
+
+ coco_categories = []
+
+ # assuming all individuals have the same name and same category id
+
+ individual = individuals[0]
+
+ category = {
+ "name": individual,
+ "id": 0,
+ "supercategory": "animal",
+ }
+
+ if individual == "single":
+ category["keypoints"] = unique_bpts
+ else:
+ category["keypoints"] = multi_bpts
+
+ coco_categories.append(category)
+
+ coco_images = []
+ coco_annotations = []
+
+ annotation_id = 0
+ image_id = -1
+ for _, file_name in enumerate(df.index):
+ data = df.loc[file_name]
+
+ # skipping all nan
+ if np.isnan(data.to_numpy()).all():
+ continue
+
+ image_id += 1
+
+ for individual_id, individual in enumerate(individuals):
+ category_id = 0
+ try:
+ kpts = (
+ data.xs(individual, level="individuals")
+ .to_numpy()
+ .reshape((-1, 2))
+ )
+ except:
+ # somehow there are duplicates. So only use the first occurrence
+ data = data.iloc[0]
+ kpts = (
+ data.xs(individual, level="individuals")
+ .to_numpy()
+ .reshape((-1, 2))
+ )
+
+ keypoints = np.zeros((len(kpts), 3))
+
+ keypoints[:, :2] = kpts
+
+ is_visible = ~pd.isnull(kpts).all(axis=1)
+
+ keypoints[:, 2] = np.where(is_visible, 2, 0)
+
+ num_keypoints = is_visible.sum()
+
+ bbox_margin = 20
+
+ xmin, ymin, xmax, ymax = calc_bboxes_from_keypoints(
+ [keypoints],
+ slack=bbox_margin,
+ clip=True,
+ )[0][:4]
+
+ w = xmax - xmin
+ h = ymax - ymin
+ area = w * h
+ bbox = np.nan_to_num([xmin, ymin, w, h])
+ keypoints = np.nan_to_num(keypoints.flatten())
+
+ annotation_id += 1
+ annotation = {
+ "image_id": image_id + image_id_offset,
+ "num_keypoints": num_keypoints,
+ "keypoints": keypoints,
+ "id": annotation_id,
+ "category_id": category_id,
+ "area": area,
+ "bbox": bbox,
+ "iscrowd": 0,
+ }
+ if np.sum(keypoints) != 0:
+ coco_annotations.append(annotation)
+
+ # I think width and height are important
+
+ if isinstance(file_name, tuple):
+ image_path = os.path.join(self.proj_root, *list(file_name))
+ else:
+ image_path = os.path.join(self.proj_root, file_name)
+
+ _, height, width = read_image_shape_fast(image_path)
+
+ image = {
+ "file_name": image_path,
+ "width": width,
+ "height": height,
+ "id": image_id + image_id_offset,
+ }
+ coco_images.append(image)
+
+ ret_obj = {
+ "images": coco_images,
+ "annotations": coco_annotations,
+ "categories": coco_categories,
+ }
+ return ret_obj
+
+
+if __name__ == "__main__":
+ dataset = MaDLCDataFrame("/mnt/md0/shaokai/daniel3mouse", "3mouse")
+ dataset.summary()
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/materialize.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/materialize.py
new file mode 100644
index 0000000000..63211b8f35
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/materialize.py
@@ -0,0 +1,797 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import json
+import os
+import pickle
+import shutil
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import scipy.io as sio
+import yaml
+
+import deeplabcut.compat as compat
+from deeplabcut.generate_training_dataset.multiple_individuals_trainingsetmanipulation import (
+ create_multianimaltraining_dataset,
+ format_multianimal_training_data,
+)
+from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
+ create_training_dataset,
+)
+from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
+ format_training_data as format_single_training_data,
+)
+from deeplabcut.utils import auxiliaryfunctions
+
+
+def get_filename(filename):
+ if type(filename) == tuple:
+ filename = os.path.join(*filename)
+ return filename
+
+
+def modify_train_test_cfg(config_path, shuffle=1, modelprefix=""):
+ # get train_cfg from main cfg
+ # use dlcr net
+ # use gradient masking
+ # set batch size as 8
+ trainposeconfigfile, testposeconfigfile, snapshotfolder = (
+ compat.return_train_network_path(
+ config_path, shuffle=shuffle, modelprefix=modelprefix, trainingsetindex=0
+ )
+ )
+
+ train_cfg = auxiliaryfunctions.read_plainconfig(trainposeconfigfile)
+ train_cfg["multi_stage"] = True
+ train_cfg["batch_size"] = 8
+ train_cfg["gradient_masking"] = True
+
+ auxiliaryfunctions.write_plainconfig(trainposeconfigfile, train_cfg)
+
+ test_cfg = auxiliaryfunctions.read_plainconfig(testposeconfigfile)
+ test_cfg["multi_stage"] = True
+ test_cfg["batch_size"] = 8
+ test_cfg["gradient_masking"] = True
+
+ auxiliaryfunctions.write_plainconfig(testposeconfigfile, test_cfg)
+
+
+class NpEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ else:
+ return super(NpEncoder, self).default(obj)
+
+
+class SingleDLC_config:
+ def __init__(self):
+ Task = "" # could be dataset name
+ project_path = ""
+ scorer = "" # random stuff
+ date = "" # random stuff
+ video_sets = "" # has to be used for labeled data
+ skeleton = "" # could be arbitrary
+ bodyparts = "" # either single or multi
+ start = 0 # not sure
+ stop = 1 # not sure
+ numframes2pick = 42 # does not matter
+ skeleton_color = "black"
+ pcutoff = 0.6
+ dotsize = 8
+ alphavalue = 0.7
+ colormap = "rainbow"
+ TrainingFraction = "" # need to be filled correctly
+ iteration = 0
+ default_net_type = "resnet_50"
+ default_augmenter = "imgaug"
+ snapshotindex = -1
+ batch_size = 8
+ cropping = False
+ croppedtraining = False
+ multianimalproject = False
+ uniquebodyparts = []
+ x1 = 0
+ x2 = 640
+ y1 = 277
+ y2 = 624
+ corer2move2 = [50, 50]
+ move2corner = True
+ identity = False
+ self.cfg = {
+ k: v for k, v in vars().items() if "__" not in k and "self" not in k
+ }
+
+ def create_cfg(self, proj_root, kwargs):
+ self.cfg.update(kwargs)
+ with open(os.path.join(proj_root, "config.yaml"), "w") as f:
+ yaml.dump(self.cfg, f)
+
+
+class MaDLC_config:
+ def __init__(self):
+ """
+ Plain text only for generating templates
+ Some variables can be configured by the user later
+ """
+
+ Task = "" # could be dataset name
+ project_path = ""
+ scorer = "" # random stuff
+ date = "" # random stuff
+ video_sets = "" # has to be used for labeled data
+ individuals = "" # number of individuals
+ multianimalbodyparts = "" # keypoints
+ skeleton = "" # could be arbitrary
+ bodyparts = "" # either single or multi
+ start = 0 # not sure
+ stop = 1 # not sure
+ numframes2pick = 42 # does not matter
+ skeleton_color = "black"
+ pcutoff = 0.6
+ dotsize = 8
+ alphavalue = 0.7
+ colormap = "rainbow"
+ TrainingFraction = "" # need to be filled correctly
+ iteration = 0
+ default_net_type = "resnet_50"
+ default_augmenter = "multi-animal-imgaug"
+ snapshotindex = -1
+ batch_size = 8
+ cropping = False
+ croppedtraining = True
+ multianimalproject = True
+ uniquebodyparts = []
+ x1 = 0
+ x2 = 640
+ y1 = 277
+ y2 = 624
+ corer2move2 = [50, 50]
+ move2corner = True
+ identity = False
+ self.cfg = {
+ k: v for k, v in vars().items() if "__" not in k and "self" not in k
+ }
+
+ def create_cfg(self, proj_root, kwargs):
+ self.cfg.update(kwargs)
+ with open(os.path.join(proj_root, "config.yaml"), "w") as f:
+ yaml.dump(self.cfg, f)
+
+
+def _generic2madlc(
+ proj_root,
+ train_images,
+ test_images,
+ train_annotations,
+ test_annotations,
+ meta,
+ deepcopy=False,
+ full_image_path=True,
+ append_image_id=True,
+):
+ """
+ Within DeepLabCut, if we don't explicitly call deeplabcut.create_traindataset(), the train and test split might just be arbitrarily messed up. So here we need to calculate train and test indices to
+
+ Args:
+ proj_root where to materialize the data
+
+ """
+
+ assert full_image_path, "DLC wants full image path"
+
+ os.makedirs(os.path.join(proj_root, "labeled-data"), exist_ok=True)
+
+ cfg_template = MaDLC_config()
+
+ individuals = [f"individual{i}" for i in range(meta["max_individuals"])]
+
+ bodyparts = meta["categories"]["keypoints"]
+
+ scorer = "maDLC_scorer"
+ # this line is taken from dlc's multi animal dataset creation function
+ train_fraction = round(
+ len(train_images) * 1.0 / (len(train_images) + len(test_images)), 2
+ )
+
+ # need to fake a video path
+ # let's use individual dataset names as fake video name
+ # merged_dataset_name = '_'.join(meta['mat_datasets'])
+ video_sets = {
+ f"{dataset_name}.mp4": {"crop": "0, 400, 0, 400"}
+ for dataset_name in meta["mat_datasets"]
+ }
+
+ modify_dict = dict(
+ Task=meta["dataset_name"],
+ project_path=proj_root,
+ individuals=individuals,
+ scorer=scorer,
+ date="March30",
+ video_sets=video_sets,
+ bodyparts="MULTI!",
+ TrainingFraction=[train_fraction],
+ multianimalbodyparts=list(bodyparts),
+ )
+
+ cfg_template.create_cfg(proj_root, modify_dict)
+ # what's special in dlc or madlc creation is that we will need to
+ # use dlc's code for creating the project structure
+ # because you don't want to write your own. It's a lot of lines of code
+ # But at least we can focus on labeled-data
+
+ imageid2datasetname = meta["imageid2datasetname"]
+
+ for dataset_name in meta["mat_datasets"]:
+ os.makedirs(
+ os.path.join(proj_root, "labeled-data", dataset_name), exist_ok=True
+ )
+
+ # also, to make sure the split is right, we will have to pass the right indices
+
+ columnindex = pd.MultiIndex.from_product(
+ [[scorer], individuals, bodyparts, ["x", "y"]],
+ names=["scorer", "individuals", "bodyparts", "coords"],
+ )
+
+ # it's important to put train first so the train_fraction parameter can work correctly
+ total_images = train_images + test_images
+ total_annotations = train_annotations + test_annotations
+
+ # DLC uses relative dest as index into dataframe
+ imageid2relativedest = {}
+ count = 0
+ for image in total_images:
+ image_id = image["id"]
+ file_name = image["file_name"]
+ image_name = file_name.split(os.sep)[-1]
+ pre, suffix = image_name.split(".")
+ if append_image_id == True:
+ dest_image_name = f"{pre}_{image_id}.{suffix}"
+ else:
+ dest_image_name = image_name
+ # the generic data has original pointers to images in the original folders
+ # Here, we have to change the image name and location of these to fit corresponding framework's convention
+
+ dataset_name = imageid2datasetname[image_id]
+
+ dest = os.path.join(proj_root, "labeled-data", dataset_name, dest_image_name)
+ if deepcopy:
+ shutil.copy(file_name, dest)
+ else:
+ try:
+ os.symlink(file_name, dest)
+ except Exception as e:
+ pass
+
+ relative_dest = os.path.join("labeled-data", dataset_name, dest_image_name)
+
+ imageid2relativedest[image_id] = relative_dest
+
+ temp_count = 0
+ for dataset_name, dataset in meta["mat_datasets"].items():
+
+ dataset_total_images = (
+ dataset.generic_train_images + dataset.generic_test_images
+ )
+ dataset_total_annotations = (
+ dataset.generic_train_annotations + dataset.generic_test_annotations
+ )
+
+ dataset_index = []
+
+ for image in dataset_total_images:
+ image_id = image["id"]
+ relative_dest = imageid2relativedest[image_id]
+ dataset_index.append(relative_dest)
+
+ raw_data = np.zeros((len(dataset_total_images), len(columnindex))) * np.nan
+ df = pd.DataFrame(raw_data, columns=columnindex, index=dataset_index)
+ # so we know where to put the next annotation if there are multiple individuals in that image
+ imageid2filledindividualcount = {}
+
+ image_ids = []
+ for anno in dataset_total_annotations:
+ keypoints = anno["keypoints"]
+ image_id = anno["image_id"]
+ image_ids.append(image_id)
+ if image_id not in imageid2filledindividualcount:
+ imageid2filledindividualcount[image_id] = 0
+ else:
+ imageid2filledindividualcount[image_id] += 1
+ individual_id = imageid2filledindividualcount[image_id]
+
+ file_name = imageid2relativedest[image_id]
+ for kpt_id, kpt_name in enumerate(meta["categories"]["keypoints"]):
+ coord = keypoints[3 * kpt_id : 3 * kpt_id + 3]
+ # note dlc does not yet have visibility flag
+ # need to be careful here to assign right keypoints to right people
+ if coord[0] > 0 and coord[1] > 0:
+ # leave them to NaN if values are 0
+ df.loc[file_name][
+ scorer, f"individual{individual_id}", kpt_name, "x"
+ ] = coord[0]
+ df.loc[file_name][
+ scorer, f"individual{individual_id}", kpt_name, "y"
+ ] = coord[1]
+ elif coord[2] == -1:
+ df.loc[file_name][
+ scorer, f"individual{individual_id}", kpt_name, "x"
+ ] = -1
+ df.loc[file_name][
+ scorer, f"individual{individual_id}", kpt_name, "y"
+ ] = -1
+ df.to_hdf(
+ os.path.join(
+ proj_root, "labeled-data", dataset_name, f"CollectedData_{scorer}.h5"
+ ),
+ key="df_with_missing",
+ mode="w",
+ )
+ # paf_graph default as None. But I am not sure how to do better
+ create_multianimaltraining_dataset(
+ os.path.join(proj_root, "config.yaml"), paf_graph=None
+ )
+
+ # dlc's merge_annotation messes up my indices, so I will need to overwrite the documentation file
+ # I could have done it in a more elegant way if I could modify part of DLC source code, but for backward compatibility reasons, overriding documentation is smarter
+
+ config_path = os.path.join(proj_root, "config.yaml")
+
+ cfg = auxiliaryfunctions.read_config(config_path)
+
+ train_folder = os.path.join(proj_root, auxiliaryfunctions.GetTrainingSetFolder(cfg))
+
+ datafilename, metafilename = auxiliaryfunctions.GetDataandMetaDataFilenames(
+ train_folder, train_fraction, 1, cfg
+ )
+
+ modify_train_test_cfg(config_path)
+
+ dlc_df = pd.read_hdf(os.path.join(train_folder, f"CollectedData_{scorer}.h5"))
+
+ # I strip off video info from the naming. For horse10, I need to get it back
+ parent_trace = {}
+
+ def _filter(image):
+ file_name = image["file_name"]
+
+ image_name = file_name.split(os.sep)[-1]
+ video_folder = file_name.split(os.sep)[-2]
+ pre, suffix = image_name.split(".")
+ image_id = image["id"]
+ if append_image_id:
+ ret = f"{pre}_{image_id}.{suffix}"
+ else:
+ ret = image_name
+ parent_trace[ret] = video_folder
+ return ret
+
+ _filter_train_images = list(map(_filter, train_images))
+ _filter_test_images = list(map(_filter, test_images))
+
+ with open(os.path.join(train_folder, "parent_trace.pickle"), "wb") as f:
+ pickle.dump(parent_trace, f)
+
+ trainIndices = [
+ idx
+ for idx, image in enumerate(dlc_df.index)
+ if get_filename(image).split(os.sep)[-1] in _filter_train_images
+ ]
+ testIndices = [
+ idx
+ for idx, image in enumerate(dlc_df.index)
+ if get_filename(image).split(os.sep)[-1] in _filter_test_images
+ ]
+
+ with open(metafilename, "rb") as f:
+ metafile = pickle.load(f)
+
+ metafile[1] = trainIndices
+ metafile[2] = testIndices
+
+ with open(metafilename, "wb") as f:
+ pickle.dump(metafile, f)
+
+ # need to overwrite the data pickle file too
+
+ nbodyparts = len(bodyparts)
+
+ if "individuals" not in dlc_df.columns.names:
+ old_idx = dlc_df.columns.to_frame()
+ old_idx.insert(0, "individuals", "")
+ dlc_df.columns = pd.MultiIndex.from_frame(old_idx)
+
+ data = format_multianimal_training_data(dlc_df, trainIndices, cfg["project_path"])
+
+ datafilename = datafilename.split(".mat")[0] + ".pickle"
+
+ print(f"overwriting data file {datafilename}")
+
+ with open(os.path.join(proj_root, datafilename), "wb") as f:
+
+ pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
+
+
+def _generic2sdlc(
+ proj_root,
+ train_images,
+ test_images,
+ train_annotations,
+ test_annotations,
+ meta,
+ deepcopy=False,
+ full_image_path=True,
+ append_image_id=True,
+):
+
+ assert full_image_path, "DLC wants full image path"
+
+ os.makedirs(os.path.join(proj_root, "labeled-data"), exist_ok=True)
+
+ cfg_template = SingleDLC_config()
+
+ bodyparts = meta["categories"]["keypoints"]
+ scorer = "singleDLC_scorer"
+
+ train_fraction = round(
+ len(train_images) * 1.0 / (len(train_images) + len(test_images)), 2
+ )
+
+ # need to fake a video path
+ # let's use individual dataset names as fake video name
+
+ video_sets = {
+ f"{dataset_name}.mp4": {"crop": "0, 400, 0, 400"}
+ for dataset_name in meta["mat_datasets"].keys()
+ }
+
+ modify_dict = dict(
+ Task=meta["dataset_name"],
+ project_path=proj_root,
+ scorer=scorer,
+ date="March30",
+ bodyparts=list(bodyparts),
+ video_sets=video_sets,
+ TrainingFraction=[train_fraction],
+ )
+
+ cfg_template.create_cfg(proj_root, modify_dict)
+
+ imageid2datasetname = meta["imageid2datasetname"]
+
+ for dataset_name in meta["mat_datasets"]:
+ os.makedirs(
+ os.path.join(proj_root, "labeled-data", dataset_name), exist_ok=True
+ )
+
+ columnindex = pd.MultiIndex.from_product(
+ [[scorer], bodyparts, ["x", "y"]], names=["scorer", "bodyparts", "coords"]
+ )
+
+ total_images = train_images + test_images
+ total_annotations = train_annotations + test_annotations
+
+ # DLC uses relative dest as index
+ imageid2relativedest = {}
+
+ for image in total_images:
+ imageid = image["id"]
+ filename = image["file_name"]
+ datasetname = imageid2datasetname[imageid]
+ count = 0
+ for image in total_images:
+ image_id = image["id"]
+ file_name = image["file_name"]
+
+ image_name = file_name.split(os.sep)[-1]
+ pre, suffix = image_name.split(".")
+
+ if append_image_id == True:
+ dest_image_name = f"{pre}_{image_id}.{suffix}"
+ else:
+ dest_image_name = image_name
+ # the generic data has original pointers to images in the original folders
+ # Here, we have to change the image name and location of these to fit corresponding framework's convention
+
+ dataset_name = imageid2datasetname[image_id]
+
+ dest = os.path.join(proj_root, "labeled-data", dataset_name, dest_image_name)
+ if deepcopy:
+ shutil.copy(file_name, dest)
+ else:
+ try:
+ os.symlink(file_name, dest)
+ except:
+ pass
+
+ if dataset_name == "AwA-Pose":
+ count += 1
+
+ relative_dest = os.path.join("labeled-data", dataset_name, dest_image_name)
+ imageid2relativedest[image_id] = relative_dest
+
+ # so we know where to put the next annotation if there are multiple individuals in that image
+
+ for dataset_name, dataset in meta["mat_datasets"].items():
+
+ dataset_total_images = (
+ dataset.generic_train_images + dataset.generic_test_images
+ )
+ dataset_total_annotations = (
+ dataset.generic_train_annotations + dataset.generic_test_annotations
+ )
+
+ dataset_index = []
+ freq = {}
+ for image in dataset_total_images:
+ filename = image["file_name"]
+
+ image_id = image["id"]
+ relative_dest = imageid2relativedest[image_id]
+
+ dataset_index.append(relative_dest)
+
+ raw_data = np.zeros((len(dataset_total_images), len(columnindex))) * np.nan
+
+ dataset_index = dataset_index
+
+ df = pd.DataFrame(raw_data, columns=columnindex, index=dataset_index)
+
+ for idx, anno in enumerate(dataset_total_annotations):
+ keypoints = np.array(anno["keypoints"])
+ image_id = anno["image_id"]
+
+ file_name = imageid2relativedest[image_id]
+
+ for kpt_id, kpt_name in enumerate(meta["categories"]["keypoints"]):
+ coord = keypoints[3 * kpt_id : 3 * kpt_id + 3]
+ # note dlc does not yet have visibility flag
+ # need to be careful here to assign right keypoints to right people
+
+ if coord[0] > 0 and coord[1] > 0:
+
+ df.loc[file_name][scorer, kpt_name, "x"] = coord[0]
+ df.loc[file_name][scorer, kpt_name, "y"] = coord[1]
+ elif coord[2] == -1:
+ # if -1, this visibility flag means a given keypoint was not annotated in the original dataset
+ df.loc[file_name][scorer, kpt_name, "x"] = -1
+ df.loc[file_name][scorer, kpt_name, "y"] = -1
+
+ df = df.dropna(how="all")
+ df.to_hdf(
+ os.path.join(
+ proj_root, "labeled-data", dataset_name, f"CollectedData_{scorer}.h5"
+ ),
+ key="df_with_missing",
+ mode="w",
+ )
+
+ create_training_dataset(os.path.join(proj_root, "config.yaml"))
+
+ # dlc's merge_annotation messes up my indices, so I will need to overwrite the documentation file
+ # I could have done it in a more elegant way if I could modify part of DLC source code, but for backward compatibility reasons, overriding documentation is smarter
+
+ config_path = os.path.join(proj_root, "config.yaml")
+
+ cfg = auxiliaryfunctions.read_config(config_path)
+
+ train_folder = os.path.join(proj_root, auxiliaryfunctions.GetTrainingSetFolder(cfg))
+
+ datafilename, metafilename = auxiliaryfunctions.GetDataandMetaDataFilenames(
+ train_folder, train_fraction, 1, cfg
+ )
+
+ modify_train_test_cfg(config_path)
+
+ dlc_df = pd.read_hdf(os.path.join(train_folder, f"CollectedData_{scorer}.h5"))
+
+ parent_trace = {}
+
+ def _filter(image):
+ file_name = image["file_name"]
+ image_name = file_name.split(os.sep)[-1]
+ video_folder = file_name.split(os.sep)[-2]
+ pre, suffix = image_name.split(".")
+ image_id = image["id"]
+ if append_image_id:
+ ret = f"{pre}_{image_id}.{suffix}"
+ else:
+ ret = image_name
+
+ parent_trace[ret] = video_folder
+
+ return ret
+
+ _filter_train_images = list(map(_filter, train_images))
+ _filter_test_images = list(map(_filter, test_images))
+
+ with open(os.path.join(train_folder, "parent_trace.pickle"), "wb") as f:
+ pickle.dump(parent_trace, f)
+
+ trainIndices = [
+ idx
+ for idx, image in enumerate(dlc_df.index)
+ if get_filename(image).split(os.sep)[-1] in _filter_train_images
+ ]
+ testIndices = [
+ idx
+ for idx, image in enumerate(dlc_df.index)
+ if get_filename(image).split(os.sep)[-1] in _filter_test_images
+ ]
+
+ with open(metafilename, "rb") as f:
+ metafile = pickle.load(f)
+
+ metafile[1] = trainIndices
+ metafile[2] = testIndices
+
+ with open(metafilename, "wb") as f:
+ pickle.dump(metafile, f)
+
+ # need to overwrite the true data file too
+ nbodyparts = len(bodyparts)
+
+ data, MatlabData = format_single_training_data(
+ dlc_df, trainIndices, nbodyparts, cfg["project_path"]
+ )
+
+ print(f"overwriting data file {datafilename}")
+
+ sio.savemat(os.path.join(datafilename), {"dataset": MatlabData})
+
+
+def _generic2coco(
+ proj_root,
+ train_images,
+ test_images,
+ train_annotations,
+ test_annotations,
+ meta,
+ deepcopy: bool = False,
+ full_image_path: bool = True,
+ append_image_id: bool = True,
+ no_image_copy: bool = False,
+):
+ """
+ Take generic data and create coco structure
+ My generic definition of coco structure:
+ images
+ ...
+ annotations
+ - train.json
+ - test.json
+
+ Args:
+ deepcopy: Only when no_image_copy=False. If False, images are not copied from
+ their original location and symlinks are created instead.
+ full_image_path: Only when no_image_copy=False. If True, the ``file_name`` for
+ the images in the annotation files contain the resolved path to the images.
+ Otherwise, a relative path is used.
+ append_image_id: Only when no_image_copy=False. Appends the image IDs in the
+ dataset to the image names.
+ no_image_copy: Instead of copying images to the COCO dataset, the full paths to
+ the images in the original dataset are used in the annotations.
+ """
+
+ os.makedirs(os.path.join(proj_root, "images"), exist_ok=True)
+ os.makedirs(os.path.join(proj_root, "annotations"), exist_ok=True)
+
+ # from new path to old_path
+ lookuptable = {}
+
+ for annotation in train_annotations + test_annotations:
+ if "iscrowd" not in annotation:
+ annotation["iscrowd"] = 0
+
+ keypoints = annotation["keypoints"]
+ for kpt_id, kpt_name in enumerate(meta["categories"]["keypoints"]):
+ coord = keypoints[3 * kpt_id : 3 * kpt_id + 3]
+ if coord[0] < 0 or coord[1] < 0:
+ coord[2] = -1
+
+ broken_links = []
+ # copying images via symbolic link
+ for image in train_images + test_images:
+ # important to resolve the filepath! Otherwise, errors can occur when running
+ # this code from Jupyter Notebooks
+ src = Path(image["file_name"]).resolve()
+ image_id = image["id"]
+
+ if not src.exists():
+ print("problem comes from", image["source_dataset"])
+ print(src)
+ broken_links.append(image_id)
+ continue
+
+ file_name = str(src)
+ dest = src
+ if not no_image_copy:
+ # in dlc, some images have same name but under different folder
+ # we used to use a parent folder to distinguish them, but it's only
+ # applicable to DLC so here it's easier to append an id into the filename
+
+ # not to repeatedly add image id in memory replay training
+ dest_image_name = src.name
+ if append_image_id:
+ dest_image_name = f"{src.stem}_{image_id}{src.suffix}"
+
+ dest = Path(proj_root) / "images" / dest_image_name
+ dest = dest.resolve()
+
+ file_name = str(Path(*dest.parts[-2:]))
+ if full_image_path:
+ file_name = str(dest)
+
+ if deepcopy:
+ shutil.copy(src, dest)
+ else:
+ try:
+ os.symlink(src, dest)
+ except Exception as err:
+ print(f"Could not create a symlink from {src} to {dest}: {err}")
+ pass
+
+ image["file_name"] = file_name
+ lookuptable[dest] = src
+
+ train_annotations = [
+ train_anno
+ for train_anno in train_annotations
+ if train_anno["image_id"] not in broken_links
+ ]
+ test_annotations = [
+ test_anno
+ for test_anno in test_annotations
+ if test_anno["image_id"] not in broken_links
+ ]
+
+ with open(os.path.join(proj_root, "annotations", "train.json"), "w") as f:
+
+ train_json_obj = dict(
+ images=train_images,
+ annotations=train_annotations,
+ categories=[meta["categories"]],
+ )
+
+ json.dump(train_json_obj, f, indent=4, cls=NpEncoder)
+
+ with open(os.path.join(proj_root, "annotations", "test.json"), "w") as f:
+ test_json_obj = dict(
+ images=test_images,
+ annotations=test_annotations,
+ categories=[meta["categories"]],
+ )
+
+ json.dump(test_json_obj, f, indent=4, cls=NpEncoder)
+
+ return lookuptable
+
+
+def mat_func_factory(framework):
+ assert framework in [
+ "coco",
+ "sdlc",
+ "madlc",
+ ], f"Does not support framework {framework}"
+ if framework == "madlc":
+ mat_func = _generic2madlc
+ elif framework == "coco":
+ mat_func = _generic2coco
+ elif framework == "sdlc":
+ mat_func = _generic2sdlc
+
+ return mat_func
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/multi.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/multi.py
new file mode 100644
index 0000000000..1b24a7a2b7
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/multi.py
@@ -0,0 +1,291 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import warnings
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets.materialize import (
+ mat_func_factory,
+)
+
+
+class MultiSourceDataset:
+ """
+ Parameters:
+ iid_ood_split: {'iid' : ['dataset1', 'dataset2'],
+ 'ood' : ['dataset3', 'dataset4'] }
+
+
+
+ """
+
+ def __init__(self, dataset_name, datasets, table_path):
+ self.datasets = datasets
+ #
+ self.name2genericdataset = {}
+
+ # useful maps for analysis
+ self.imageid2filename = {}
+ self.imageid2datasetname = {}
+ self.datasetname2imageids = {}
+ #
+ self.dataset_name = dataset_name
+
+ names = []
+ for dataset in datasets:
+
+ # Must project datasets to same keypoint space before merging
+ if table_path != None:
+ dataset.project_with_conversion_table(table_path)
+ name = dataset.meta["dataset_name"]
+ names.append(name)
+ self.name2genericdataset[name] = dataset
+
+ self.meta = {}
+ self.meta["dataset_name"] = dataset_name
+ # after conversion, all datasets have same categories
+ self.meta["categories"] = dataset.meta["categories"]
+
+ # map id from local scope to global
+ self._update_imgids()
+
+ (
+ self.train_images,
+ self.test_images,
+ self.train_annotations,
+ self.test_annotations,
+ ) = self._merge_datasets(self.name2genericdataset)
+ self.meta["name2genericdataset"] = self.name2genericdataset
+
+ # only build maps after images are merged and ids are in global scope
+ self._build_maps()
+
+ def summary(self):
+ print(f"Summary of dataset {self.dataset_name}")
+ print("Decomposition of multi source datasets:")
+ for dataset_name, dataset in self.name2genericdataset.items():
+ n_images = len(dataset.generic_train_images) + len(
+ dataset.generic_test_images
+ )
+ n_annotations = len(dataset.generic_train_annotations) + len(
+ dataset.generic_test_annotations
+ )
+ print(f"{dataset_name} has {n_images} images, {n_annotations} annotations")
+
+ print(f"total train images : {len(self.train_images)}")
+ print(f"total test images : {len(self.test_images)}")
+
+ def _build_maps(self):
+
+ # shared by both scenarios
+
+ species_set = set()
+ for dataset_name, dataset in self.name2genericdataset.items():
+ # I could of course do this during merge to save compute, but doing it here makes the logic cleaner to understand
+ total_images = dataset.generic_train_images + dataset.generic_test_images
+
+ for image in total_images:
+ image_id = image["id"]
+ image_name = image["file_name"]
+ self.imageid2filename[image_id] = image_name
+
+ self.imageid2datasetname[image_id] = dataset_name
+
+ if dataset_name == "AwA-Pose":
+ species_set.add(image_name.split("/")[-1].split("_")[0])
+ self.meta["imageid2datasetname"] = self.imageid2datasetname
+
+ max_num = 0
+ for dataset_name, dataset in self.name2genericdataset.items():
+ max_num = max(max_num, dataset.meta["max_individuals"])
+ self.meta["max_individuals"] = max_num
+ dataset_name = self.meta["dataset_name"]
+ print(f"Max individual in {dataset_name} is {max_num}")
+
+ def whether_anno_image_match(self, images, annotations):
+ """
+ Every image id should be annotated at least once
+ There should not be any image that is not being annotated
+ There should not be any annotation for beyond the set of given images
+ """
+
+ image_ids = set([image["id"] for image in images])
+
+ annotation_image_ids = set([anno["image_id"] for anno in annotations])
+
+ if image_ids != annotation_image_ids:
+ print("images-annotations", image_ids - annotation_image_ids)
+ print("annotations-images", annotation_image_ids - image_ids)
+
+ warnings.warn("annotation and image ids do not match")
+
+ # This is constrain is too hard
+ # assert len(annotation_image_ids - image_ids) == 0, "You can't have annotation on non-existed images"
+
+ def _update_imgids(self):
+ """
+ update image ids for both image and annotation
+
+ If datasets are merged, their image id, annotation id will conflict because they are defined within their own local scope. Therefore, we will need to put these ids in the global scope
+
+ """
+
+ from collections import defaultdict
+
+ dataset_id_pool = defaultdict(set)
+ all_datasets = self.name2genericdataset.values()
+
+ total_number_images = 0
+ total_number_annotations = 0
+ for dataset in all_datasets:
+ total_number_images += len(dataset.generic_train_images) + len(
+ dataset.generic_test_images
+ )
+ total_number_annotations += len(dataset.generic_train_annotations) + len(
+ dataset.generic_test_annotations
+ )
+
+ global_image_id_pool = set(range(total_number_images))
+ global_annotation_id_pool = set(range(total_number_annotations))
+
+ for dataset_name, dataset in self.name2genericdataset.items():
+
+ local_image_id_map = defaultdict(int)
+ local_anno_id_map = defaultdict(int)
+
+ traintest_images = (
+ dataset.generic_train_images + dataset.generic_test_images
+ )
+ traintest_annotations = (
+ dataset.generic_train_annotations + dataset.generic_test_annotations
+ )
+
+ for img in traintest_images:
+
+ new_image_id = global_image_id_pool.pop()
+ local_image_id_map[img["id"]] = new_image_id
+ img["id"] = new_image_id
+ dataset_id_pool[dataset_name].add(img["id"])
+
+ for anno in traintest_annotations:
+ anno["image_id"] = local_image_id_map[anno["image_id"]]
+ new_anno_id = global_annotation_id_pool.pop()
+ local_anno_id_map[anno["id"]] = new_anno_id
+ anno["id"] = new_anno_id
+
+ self.whether_anno_image_match(traintest_images, traintest_annotations)
+
+ from functools import reduce
+
+ count = 0
+ for k, v in dataset_id_pool.items():
+ count += len(v)
+ print("size of the summation", count)
+ union = reduce(set.union, dataset_id_pool.values())
+ print("size of the union", len(union))
+
+ def _merge_datasets(self, name2dataset):
+ """
+ Merged datasets into common list
+
+ # only do this when iid/ood split is done
+
+ """
+
+ merged_train_images = []
+ merged_test_images = []
+ merged_train_annotations = []
+ merged_test_annotations = []
+
+ for dataset_name, dataset in name2dataset.items():
+
+ train_images = dataset.generic_train_images
+ test_images = dataset.generic_test_images
+ train_annotations = dataset.generic_train_annotations
+ test_annotations = dataset.generic_test_annotations
+
+ merged_train_images.extend(train_images)
+ merged_test_images.extend(test_images)
+ merged_train_annotations.extend(train_annotations)
+ merged_test_annotations.extend(test_annotations)
+
+ print("Checking merged dataset")
+
+ merged_traintest_images = merged_train_images + merged_test_images
+ merged_traintest_annotations = (
+ merged_train_annotations + merged_test_annotations
+ )
+
+ self.whether_anno_image_match(
+ merged_traintest_images, merged_traintest_annotations
+ )
+
+ return (
+ merged_train_images,
+ merged_test_images,
+ merged_train_annotations,
+ merged_test_annotations,
+ )
+
+ def __eq__(self, other_dataset):
+
+ if isinstance(other_dataset, BasePoseDataset):
+
+ train_images1 = set(map(raw_2_imagename_with_id, self.train_images))
+ train_images2 = set(
+ map(raw_2_imagename, other_dataset.generic_train_images)
+ )
+
+ test_images1 = set(map(raw_2_imagename_with_id, self.test_images))
+ test_images2 = set(map(raw_2_imagename, other_dataset.generic_test_images))
+ if train_images1 == train_images2 and test_images1 == test_images2:
+ print(
+ f'dataset {self.meta["dataset_name"]} and {other_dataset.meta["dataset_name"]} are equivalent'
+ )
+ return True
+ else:
+ print(
+ f'dataset {self.meta["dataset_name"]} and {other_dataset.meta["dataset_name"]} are NOT equivalent'
+ )
+ return False
+
+ else:
+ return NotImplementedError("Not existed")
+
+ def materialize(
+ self,
+ proj_root,
+ framework="coco",
+ train_all=False,
+ deepcopy=False,
+ full_image_path=True,
+ ):
+
+ # can't be set to true at the same time. This will cause bugs
+ assert sum([train_all, full_image_path]) != 2
+
+ mat_func = mat_func_factory(framework)
+
+ self.meta["mat_datasets"] = self.name2genericdataset
+
+ if train_all:
+ # for pretrian phase, we can just train everything including the test part
+ self.train_images += self.test_images
+ self.train_annotations += self.test_annotations
+
+ mat_func(
+ proj_root,
+ self.train_images,
+ self.test_images,
+ self.train_annotations,
+ self.test_annotations,
+ self.meta,
+ deepcopy=deepcopy,
+ full_image_path=full_image_path,
+ )
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc.py
new file mode 100644
index 0000000000..8d7b419654
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc.py
@@ -0,0 +1,135 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base_dlc import (
+ BaseDLCPoseDataset,
+)
+from deeplabcut.modelzoo.generalized_data_converter.datasets.utils import (
+ calc_bboxes_from_keypoints,
+ read_image_shape_fast,
+)
+
+
+class SingleDLCPoseDataset(BaseDLCPoseDataset):
+ """
+ The philosophy is to assume the dataset is already created so this class is not
+ responsible for creating training dataset
+ """
+
+ def __init__(self, proj_root, dataset_name, shuffle=1, modelprefix=""):
+ super(SingleDLCPoseDataset, self).__init__(
+ proj_root, dataset_name, shuffle=shuffle, modelprefix=modelprefix
+ )
+
+ # overriding max_individuals
+ self.meta["max_individuals"] = 1
+
+ def _df2generic(self, df, image_id_offset=0):
+
+ bpts = df.columns.get_level_values("bodyparts").unique().tolist()
+
+ coco_categories = []
+
+ # single animal only has individual0
+
+ category = {
+ "name": "individual0",
+ "id": 0,
+ "supercategory": "animal",
+ }
+
+ category["keypoints"] = bpts
+
+ coco_categories.append(category)
+
+ coco_images = []
+ coco_annotations = []
+
+ annotation_id = 0
+ image_id = -1
+
+ for _, file_name in enumerate(df.index):
+ data = df.loc[file_name]
+
+ # skipping all nan
+
+ if np.isnan(data.to_numpy()).all():
+ continue
+
+ image_id += 1
+ category_id = 0
+ kpts = data.to_numpy().reshape(-1, 2)
+ keypoints = np.zeros((len(kpts), 3))
+
+ keypoints[:, :2] = kpts
+
+ is_visible = ~pd.isnull(kpts).all(axis=1)
+
+ keypoints[:, 2] = np.where(is_visible, 2, 0)
+
+ num_keypoints = is_visible.sum()
+
+ bbox_margin = 20
+
+ xmin, ymin, xmax, ymax = calc_bboxes_from_keypoints(
+ [keypoints],
+ slack=bbox_margin,
+ clip=True,
+ )[0][:4]
+
+ w = xmax - xmin
+ h = ymax - ymin
+ area = w * h
+ bbox = np.nan_to_num([xmin, ymin, w, h])
+ keypoints = np.nan_to_num(keypoints.flatten())
+
+ annotation_id += 1
+ annotation = {
+ "image_id": image_id + image_id_offset,
+ "num_keypoints": num_keypoints,
+ "keypoints": keypoints,
+ "id": annotation_id,
+ "category_id": category_id,
+ "area": area,
+ "bbox": bbox,
+ "iscrowd": 0,
+ }
+ if np.sum(keypoints) != 0:
+
+ coco_annotations.append(annotation)
+
+ # I think width and height are important
+
+ if isinstance(file_name, tuple):
+ image_path = os.path.join(self.proj_root, *list(file_name))
+ else:
+ image_path = os.path.join(self.proj_root, file_name)
+
+ _, height, width = read_image_shape_fast(image_path)
+
+ image = {
+ "file_name": image_path,
+ "width": width,
+ "height": height,
+ "id": image_id + image_id_offset,
+ }
+ coco_images.append(image)
+
+ ret_obj = {
+ "images": coco_images,
+ "annotations": coco_annotations,
+ "categories": coco_categories,
+ }
+ return ret_obj
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc_dataframe.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc_dataframe.py
new file mode 100644
index 0000000000..e6e8fd5828
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/single_dlc_dataframe.py
@@ -0,0 +1,243 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
+ parse_video_filenames,
+)
+from deeplabcut.modelzoo.generalized_data_converter.datasets.base import BasePoseDataset
+from deeplabcut.modelzoo.generalized_data_converter.datasets.utils import (
+ calc_bboxes_from_keypoints,
+ read_image_shape_fast,
+)
+from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions, conversioncode
+
+
+def merge_annotateddatasets(cfg):
+ """
+ Merges all the h5 files for all labeled-datasets (from individual videos).
+
+ This is a bit of a mess because of cross platform compatibility.
+
+ Within platform comp. is straightforward. But if someone labels on windows and wants to train on a unix cluster or colab...
+ """
+ AnnotationData = []
+ data_path = Path(os.path.join(cfg["project_path"], "labeled-data"))
+ videos = cfg["video_sets"].keys()
+ video_filenames = parse_video_filenames(videos)
+ for filename in video_filenames:
+ file_path = os.path.join(
+ data_path / filename, f'CollectedData_{cfg["scorer"]}.h5'
+ )
+ try:
+ data = pd.read_hdf(file_path)
+ conversioncode.guarantee_multiindex_rows(data)
+ if data.columns.levels[0][0] != cfg["scorer"]:
+ print(
+ f"{file_path} labeled by a different scorer. This data will not be utilized in training dataset creation. If you need to merge datasets across scorers, see https://github.com/DeepLabCut/DeepLabCut/wiki/Using-labeled-data-in-DeepLabCut-that-was-annotated-elsewhere-(or-merge-across-labelers)"
+ )
+ continue
+ AnnotationData.append(data)
+ except FileNotFoundError:
+ print(file_path, " not found (perhaps not annotated).")
+
+ if not len(AnnotationData):
+ print(
+ "Annotation data was not found by splitting video paths (from config['video_sets']). An alternative route is taken..."
+ )
+ AnnotationData = conversioncode.merge_windowsannotationdataONlinuxsystem(cfg)
+ if not len(AnnotationData):
+ print("No data was found!")
+ return
+
+ AnnotationData = pd.concat(AnnotationData).sort_index()
+ # When concatenating DataFrames with misaligned column labels,
+ # all sorts of reordering may happen (mainly depending on 'sort' and 'join')
+ # Ensure the 'bodyparts' level agrees with the order in the config file.
+ if cfg.get("multianimalproject", False):
+ (
+ _,
+ uniquebodyparts,
+ multianimalbodyparts,
+ ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
+ bodyparts = multianimalbodyparts + uniquebodyparts
+ else:
+ bodyparts = cfg["bodyparts"]
+ AnnotationData = AnnotationData.reindex(
+ bodyparts, axis=1, level=AnnotationData.columns.names.index("bodyparts")
+ )
+
+ return AnnotationData
+
+
+class SingleDLCDataFrame(BasePoseDataset):
+
+ def __init__(self, proj_root, dataset_name):
+ super(SingleDLCDataFrame, self).__init__()
+ self.meta["max_individuals"] = 1
+ assert proj_root != None and dataset_name != None
+ self.proj_root = proj_root
+ self.dataset_name = dataset_name
+ self.meta["dataset_name"] = dataset_name
+ self.meta["proj_root"] = proj_root
+ config_path = Path(proj_root) / "config.yaml"
+ # read config
+ cfg = auxiliaryfunctions.read_config(config_path)
+ # get the train folder
+
+ Data = merge_annotateddatasets(
+ cfg,
+ )
+
+ # now with this data, we construct necessary generic data
+
+ self.dlc_df = Data
+
+ images = self.dlc_df.index
+
+ ratio = 0.9
+
+ df_train = self.dlc_df.iloc[: int(len(images) * ratio)]
+ df_test = self.dlc_df.iloc[int(len(images) * ratio) :]
+
+ self.coco_train = self._df2generic(df_train)
+
+ offset = len(self.coco_train["images"])
+
+ self.coco_test = self._df2generic(df_test, image_id_offset=offset)
+
+ self.populate_generic()
+
+ def populate_generic(self):
+
+ self.generic_train_images = self.coco_train["images"]
+ self.generic_test_images = self.coco_test["images"]
+ self.generic_train_annotations = self.coco_train["annotations"]
+ self.generic_test_annotations = self.coco_test["annotations"]
+
+ self.meta["categories"] = self.coco_test["categories"][0]
+
+ # to build maps for later analysis
+ self._build_maps()
+
+ print(f"Before checking trainset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_train_images, self.generic_train_annotations
+ )
+
+ print(f"Before checking testset {self.meta['dataset_name']}")
+
+ self.whether_anno_image_match(
+ self.generic_test_images, self.generic_test_annotations
+ )
+
+ def _df2generic(self, df, image_id_offset=0):
+
+ bpts = df.columns.get_level_values("bodyparts").unique().tolist()
+
+ coco_categories = []
+
+ # single animal only has individual0
+
+ category = {
+ "name": "individual0",
+ "id": 0,
+ "supercategory": "animal",
+ }
+
+ category["keypoints"] = bpts
+
+ coco_categories.append(category)
+
+ coco_images = []
+ coco_annotations = []
+
+ annotation_id = 0
+ image_id = -1
+
+ for _, file_name in enumerate(df.index):
+ data = df.loc[file_name]
+
+ # skipping all nan
+
+ if np.isnan(data.to_numpy()).all():
+ continue
+
+ image_id += 1
+ category_id = 0
+ kpts = data.to_numpy().reshape(-1, 2)
+ keypoints = np.zeros((len(kpts), 3))
+
+ keypoints[:, :2] = kpts
+
+ is_visible = ~pd.isnull(kpts).all(axis=1)
+
+ keypoints[:, 2] = np.where(is_visible, 2, 0)
+
+ num_keypoints = is_visible.sum()
+
+ bbox_margin = 20
+
+ xmin, ymin, xmax, ymax = calc_bboxes_from_keypoints(
+ [keypoints],
+ slack=bbox_margin,
+ clip=True,
+ )[0][:4]
+
+ w = xmax - xmin
+ h = ymax - ymin
+ area = w * h
+ bbox = np.nan_to_num([xmin, ymin, w, h])
+ keypoints = np.nan_to_num(keypoints.flatten())
+
+ annotation_id += 1
+ annotation = {
+ "image_id": image_id + image_id_offset,
+ "num_keypoints": num_keypoints,
+ "keypoints": keypoints,
+ "id": annotation_id,
+ "category_id": category_id,
+ "area": area,
+ "bbox": bbox,
+ "iscrowd": 0,
+ }
+ if np.sum(keypoints) != 0:
+
+ coco_annotations.append(annotation)
+
+ # I think width and height are important
+
+ if isinstance(file_name, tuple):
+ image_path = os.path.join(self.proj_root, *list(file_name))
+ else:
+ image_path = os.path.join(self.proj_root, file_name)
+
+ _, height, width = read_image_shape_fast(image_path)
+
+ image = {
+ "file_name": image_path,
+ "width": width,
+ "height": height,
+ "id": image_id + image_id_offset,
+ }
+ coco_images.append(image)
+
+ ret_obj = {
+ "images": coco_images,
+ "annotations": coco_annotations,
+ "categories": coco_categories,
+ }
+ return ret_obj
diff --git a/deeplabcut/modelzoo/generalized_data_converter/datasets/utils.py b/deeplabcut/modelzoo/generalized_data_converter/datasets/utils.py
new file mode 100644
index 0000000000..d04e92201a
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/datasets/utils.py
@@ -0,0 +1,40 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from functools import lru_cache
+
+import numpy as np
+from PIL import Image
+
+
+def calc_bboxes_from_keypoints(data, slack=0, offset=0, clip=False):
+ data = np.asarray(data)
+ if data.shape[-1] < 3:
+ raise ValueError("Data should be of shape (n_animals, n_bodyparts, 3)")
+
+ if data.ndim != 3:
+ data = np.expand_dims(data, axis=0)
+ bboxes = np.full((data.shape[0], 5), np.nan)
+ bboxes[:, :2] = np.nanmin(data[..., :2], axis=1) - slack # X1, Y1
+ bboxes[:, 2:4] = np.nanmax(data[..., :2], axis=1) + slack # X2, Y2
+ bboxes[:, -1] = np.nanmean(data[..., 2]) # Average confidence
+ bboxes[:, [0, 2]] += offset
+ if clip:
+ coord = bboxes[:, :4]
+ coord[coord < 0] = 0
+ return bboxes
+
+
+@lru_cache(maxsize=None)
+def read_image_shape_fast(path):
+ # Blazing fast and does not load the image into memory
+ with Image.open(path) as img:
+ width, height = img.size
+ return len(img.getbands()), height, width
diff --git a/deeplabcut/modelzoo/generalized_data_converter/utils.py b/deeplabcut/modelzoo/generalized_data_converter/utils.py
new file mode 100644
index 0000000000..368e0380b9
--- /dev/null
+++ b/deeplabcut/modelzoo/generalized_data_converter/utils.py
@@ -0,0 +1,324 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import glob
+import os
+import pickle
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from deeplabcut.utils import auxiliaryfunctions
+from deeplabcut.modelzoo.generalized_data_converter.datasets.materialize import (
+ SingleDLC_config,
+)
+
+
+def threshold_kpts(config_path, h5path, threshold_mean=0.9, threshold_min=0.1):
+
+ df = pd.read_hdf(h5path)
+
+ scorer = df.columns.get_level_values("scorer").unique()[0]
+ try:
+ data = df[scorer]["individual0"]
+ except:
+ data = df[scorer]
+
+ cfg = auxiliaryfunctions.read_config(config_path)
+
+ bodyparts = cfg["multianimalbodyparts"]
+
+ thresholded_bpts = []
+
+ for bpt in bodyparts:
+ _mean = data[bpt]["likelihood"].mean()
+ _min = data[bpt]["likelihood"].min()
+ _var = data[bpt]["likelihood"].var()
+ if _mean > threshold_mean and _min > threshold_min:
+ thresholded_bpts.append(bpt)
+ print(bpt, "mean", _mean)
+ print(bpt, "min", _min)
+ print(bpt, "var", _var)
+
+ print("thresholded kpts", thresholded_bpts)
+ return thresholded_bpts
+ ret = []
+ print(ret)
+ return ret
+
+
+def create_dummy_config_file_from_h5(
+ proj_root, reference_h5, taskname="dummytask", scorer="dummyscorer", date="March30"
+):
+ """
+ Assuming at least labeled-data folder is there
+ """
+
+ cfg_template = SingleDLC_config()
+
+ df = pd.read_hdf(reference_h5)
+
+ print(df)
+
+ pattern = glob.glob(os.path.join(proj_root, "labeled-data", "*"))
+
+ labeled_folders = [f.split("/")[-1] for f in pattern]
+
+ video_sets = {
+ f"{folder}.mp4": {"crop": "0, 400, 0, 400"} for folder in labeled_folders
+ }
+
+ # bodyparts = df[scorer]['bodyparts']
+
+ bodyparts = list(df.columns.get_level_values("bodyparts").unique())
+ scorer = df.columns.get_level_values("scorer").unique()[0]
+
+ modify_dict = dict(
+ Task=taskname,
+ project_path=proj_root,
+ scorer=scorer,
+ date=date,
+ video_sets=video_sets,
+ bodyparts=bodyparts,
+ TrainingFraction=[0.95],
+ )
+
+ cfg_template.create_cfg(proj_root, modify_dict)
+
+
+def create_dummy_config_file_from_pickle(
+ proj_root,
+ reference_pickle,
+ video_path,
+ taskname="dummytask",
+ scorer="dummyscorer",
+ date="March30",
+):
+ """
+ Assuming at least labeled-data folder is there
+ """
+
+ cfg_template = SingleDLC_config()
+
+ with open(reference_pickle, "rb") as f:
+
+ pickle_obj = pickle.load(f)
+
+ # bodyparts = pickle_obj['keypoint_names']
+ bodyparts = [
+ "tail",
+ "spine4",
+ "spine3",
+ "spine2",
+ "spine1",
+ "head",
+ "nose",
+ "right ear",
+ "left ear",
+ ]
+
+ video_name = video_path.split("/")[-1]
+
+ video_sets = {f"{video_path}": {"crop": "0, 400, 0, 400"}}
+
+ modify_dict = dict(
+ Task=taskname,
+ project_path=proj_root,
+ scorer=scorer,
+ date=date,
+ video_sets=video_sets,
+ bodyparts=bodyparts,
+ TrainingFraction=[0.95],
+ )
+
+ cfg_template.create_cfg(".", modify_dict)
+
+
+def create_video_h5_from_pickle(proj_root, cfg, reference_pickle, videopath):
+
+ with open(reference_pickle, "rb") as f:
+
+ pickle_obj = pickle.load(f)
+
+ # bodyparts = pickle_obj['keypoint_names']
+
+ bodyparts = [
+ "tail",
+ "spine4",
+ "spine3",
+ "spine2",
+ "spine1",
+ "head",
+ "nose",
+ "right ear",
+ "left ear",
+ ]
+
+ video_name = videopath.split("/")[-1]
+
+ video_key = f"{video_name}" # .replace('.top.ir.mp4', '')
+
+ print("video_key", video_key)
+
+ print(list(pickle_obj.keys()))
+
+ detections = pickle_obj[video_key]
+
+ nframes = len(detections)
+
+ xyz_labs = ["x", "y", "likelihood"]
+
+ scorer = cfg["scorer"]
+
+ keypoint_names = cfg["bodyparts"]
+
+ product = [[scorer], keypoint_names, xyz_labs]
+
+ names = ["scorer", "bodyparts", "coords"]
+ columnindex = pd.MultiIndex.from_product(product, names=names)
+ imagenames = [f"frame{i}" for i in range(nframes)]
+ data = np.zeros((len(imagenames), len(columnindex))) * np.nan
+ df = pd.DataFrame(data, columns=columnindex, index=imagenames)
+
+ for imagename, kpts in zip(imagenames, detections):
+
+ for kpt_id, kpt_name in enumerate(keypoint_names):
+
+ df.loc[imagename][scorer, kpt_name, "x"] = kpts[kpt_id, 0]
+ df.loc[imagename][scorer, kpt_name, "y"] = kpts[kpt_id, 1]
+ df.loc[imagename][scorer, kpt_name, "likelihood"] = kpts[kpt_id, 2]
+
+ vname = Path(videopath).stem
+ DLCscorer = ""
+
+ coords = [0, 400, 0, 400]
+ trainFraction = cfg["TrainingFraction"][0]
+ modelfolder = os.path.join(
+ cfg["project_path"],
+ str(auxiliaryfunctions.get_model_folder(trainFraction, 0, cfg)),
+ )
+
+ path_test_config = Path(modelfolder) / "test" / "pose_cfg.yaml"
+ test_cfg = auxiliaryfunctions.read_plainconfig(path_test_config)
+
+ start = 0
+ stop = 10
+ fps = 10
+ dictionary = {
+ "start": start,
+ "stop": stop,
+ "run_duration": stop - start,
+ "Scorer": DLCscorer,
+ "DLC-model-config file": test_cfg,
+ "fps": fps,
+ "batch_size": test_cfg["batch_size"],
+ "frame_dimensions": (400, 400),
+ "nframes": nframes,
+ "iteration (active-learning)": cfg["iteration"],
+ "cropping": cfg["cropping"],
+ "training set fraction": trainFraction,
+ "cropping_parameters": coords,
+ }
+ metadata = {"data": dictionary}
+
+ dataname = os.path.join(proj_root, vname + DLCscorer + ".h5")
+
+ metadata_path = dataname.split(".h5")[0] + "_meta.pickle"
+
+ with open(metadata_path, "wb") as f:
+ pickle.dump(metadata, f, pickle.HIGHEST_PROTOCOL)
+
+ df.to_hdf(dataname, "df_with_missing", format="table", mode="w")
+
+
+def add_skeleton(config_path, pretrain_model_name):
+
+ modelzoo_names = ["superquadruped", "supertopview"]
+
+ assert pretrain_model_name in modelzoo_names
+
+ super_quadruped = [
+ ("left_eye", "right_eye"),
+ ("left_eye", "left_earbase"),
+ ("right_eye", "right_earbase"),
+ ("left_eye", "nose"),
+ ("right_eye", "nose"),
+ ("nose", "throat_base"),
+ ("throat_base", "back_base"),
+ ("tail_base", "back_base"),
+ ("throat_base", "front_left_thai"),
+ ("front_left_thai", "front_left_knee"),
+ ("front_left_knee", "front_left_paw"),
+ ("throat_base", "front_right_thai"),
+ ("front_right_thai", "front_right_knee"),
+ ("front_right_knee", "front_right_paw"),
+ ("tail_base", "back_left_thai"),
+ ("back_left_thai", "back_left_knee"),
+ ("back_left_knee", "back_left_paw"),
+ ("tail_base", "back_right_thai"),
+ ("back_right_thai", "back_right_knee"),
+ ("back_right_knee", "back_right_paw"),
+ ]
+
+ skeleton_dict = {"superquadruped": super_quadruped, "supertopview": None}
+
+ skeleton = skeleton_dict[pretrain_model_name]
+
+ cfg = auxiliaryfunctions.read_config(config_path)
+ cfg["skeleton"] = skeleton
+ print(f"overwriting skeleton for {config_path}")
+ auxiliaryfunctions.write_config(config_path, cfg)
+
+
+def customized_colormap(config_path):
+ # look for all symmetric keypoints
+ # make symmetric keypoints the same color
+
+ cfg = auxiliaryfunctions.read_config(config_path)
+ bodyparts = cfg["multianimalbodyparts"]
+ n_bodyparts = len(cfg["multianimalbodyparts"])
+
+ import matplotlib.pyplot as plt
+
+ cmap = plt.cm.get_cmap("rainbow", n_bodyparts)
+
+ colors = [cmap(i) for i in range(n_bodyparts)]
+
+ visited = set()
+ for kpt_id in range(len(bodyparts)):
+
+ bodypart = bodyparts[kpt_id]
+ if "left" in bodypart:
+ ref_color = colors[kpt_id]
+ temp = bodypart.replace("left", "right")
+ if temp in bodyparts:
+ temp_id = bodyparts.index(temp)
+ colors[temp_id] = ref_color
+
+ def ret_function(i):
+ return colors[i]
+
+ return ret_function
+
+
+def create_modelprefix(modelprefix):
+ import shutil
+
+ shutil.copytree(
+ "template-dlc-models",
+ os.path.join(modelprefix, "dlc-models"),
+ dirs_exist_ok=True,
+ )
+
+
+if __name__ == "__main__":
+
+ customized_colormap("hei")
diff --git a/deeplabcut/pose_estimation_tensorflow/superanimal_configs/supertopview.yaml b/deeplabcut/modelzoo/model_configs/dlcrnet.yaml
similarity index 53%
rename from deeplabcut/pose_estimation_tensorflow/superanimal_configs/supertopview.yaml
rename to deeplabcut/modelzoo/model_configs/dlcrnet.yaml
index a0b2da3064..570c1a0aaa 100644
--- a/deeplabcut/pose_estimation_tensorflow/superanimal_configs/supertopview.yaml
+++ b/deeplabcut/modelzoo/model_configs/dlcrnet.yaml
@@ -1,62 +1,51 @@
-all_joints:
-- - 0
-- - 1
-- - 2
-- - 3
-- - 4
-- - 5
-- - 6
-- - 7
-- - 8
-- - 9
-- - 10
-- - 11
-- - 12
-- - 13
-- - 14
-- - 15
-- - 16
-- - 17
-- - 18
-- - 19
-- - 20
-- - 21
-- - 22
-- - 23
-- - 24
-- - 25
-- - 26
-all_joints_names:
-- nose
-- left_ear
-- right_ear
-- left_ear_tip
-- right_ear_tip
-- left_eye
-- right_eye
-- neck
-- mid_back
-- mouse_center
-- mid_backend
-- mid_backend2
-- mid_backend3
-- tail_base
-- tail1
-- tail2
-- tail3
-- tail4
-- tail5
-- left_shoulder
-- left_midside
-- left_hip
-- right_shoulder
-- right_midside
-- right_hip
-- tail_end
-- head_midpoint
+ # Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+
+ # Project path (change when moving around)
+project_path:
+
+ # Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+
+ # Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff:
+dotsize:
+alphavalue:
+colormap:
+
+ # Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+batch_size: 1
+
+ # Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+ #if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+
+ # Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
alpha_r: 0.02
apply_prob: 0.5
-batch_size: 1
clahe: true
claheratio: 0.1
crop_sampling: hybrid
@@ -64,7 +53,7 @@ crop_size:
- 400
- 400
cropratio: 0.4
-dataset: training-datasets/iteration-0/UnaugmentedDataSet_ma_supertopviewMarch30/ma_supertopview_maDLC_scorer95shuffle1.pickle
+dataset:
dataset_type: multi-animal-imgaug
decay_steps: 30000
display_iters: 500
@@ -90,7 +79,11 @@ locref_stdev: 7.2801
lr_init: 0.0005
max_input_size: 1500
max_shift: 0.4
-metadataset: training-datasets/iteration-0/UnaugmentedDataSet_ma_supertopviewMarch30/Documentation_data-ma_supertopview_95shuffle1.pickle
+mean_pixel:
+- 123.68
+- 116.779
+- 103.939
+metadataset:
min_input_size: 64
mirror: false
multi_stage: true
@@ -114,7 +107,6 @@ partaffinityfield_graph: []
partaffinityfield_predict: false
pos_dist_thresh: 17
pre_resize: []
-project_path:
rotation: 25
rotratio: 0.4
save_iters: 10000
@@ -122,5 +114,8 @@ scale_jitter_lo: 0.5
scale_jitter_up: 1.25
sharpen: false
sharpenratio: 0.3
+stride: 8.0
weigh_only_present_joints: false
gradient_masking: true
+weight_decay: 0.0001
+weigh_part_predictions: false
diff --git a/deeplabcut/modelzoo/model_configs/fasterrcnn_mobilenet_v3_large_fpn.yaml b/deeplabcut/modelzoo/model_configs/fasterrcnn_mobilenet_v3_large_fpn.yaml
new file mode 100644
index 0000000000..6d4e11b55a
--- /dev/null
+++ b/deeplabcut/modelzoo/model_configs/fasterrcnn_mobilenet_v3_large_fpn.yaml
@@ -0,0 +1,51 @@
+data:
+ colormode: RGB
+ inference:
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [ 1.0, 1.0 ]
+ translation: 40
+ collate:
+ type: ResizeFromDataSizeCollate
+ min_scale: 0.4
+ max_scale: 1.0
+ min_short_side: 128
+ max_short_side: 1152
+ multiple_of: 32
+ to_square: false
+ hflip: true
+ normalize_images: true
+device: auto
+model:
+ type: FasterRCNN
+ variant: fasterrcnn_mobilenet_v3_large_fpn
+ box_score_thresh: 0.6
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+runner:
+ type: DetectorTrainingRunner
+ key_metric: "test.mAP@50:95"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-5
+ scheduler:
+ type: LRListScheduler
+ params:
+ milestones: [ 90 ]
+ lr_list: [ [ 1e-6 ] ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 1
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 250
diff --git a/deeplabcut/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml b/deeplabcut/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml
new file mode 100644
index 0000000000..27d147e339
--- /dev/null
+++ b/deeplabcut/modelzoo/model_configs/fasterrcnn_resnet50_fpn_v2.yaml
@@ -0,0 +1,51 @@
+data:
+ colormode: RGB
+ inference:
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [ 1.0, 1.0 ]
+ translation: 40
+ collate:
+ type: ResizeFromDataSizeCollate
+ min_scale: 0.4
+ max_scale: 1.0
+ min_short_side: 128
+ max_short_side: 1152
+ multiple_of: 32
+ to_square: false
+ hflip: true
+ normalize_images: true
+device: auto
+model:
+ type: FasterRCNN
+ variant: fasterrcnn_resnet50_fpn_v2
+ box_score_thresh: 0.6
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+runner:
+ type: DetectorTrainingRunner
+ key_metric: "test.mAP@50:95"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-5
+ scheduler:
+ type: LRListScheduler
+ params:
+ milestones: [ 90 ]
+ lr_list: [ [ 1e-6 ] ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 1
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 250
\ No newline at end of file
diff --git a/deeplabcut/modelzoo/model_configs/hrnet_w32.yaml b/deeplabcut/modelzoo/model_configs/hrnet_w32.yaml
new file mode 100644
index 0000000000..011d7727c3
--- /dev/null
+++ b/deeplabcut/modelzoo/model_configs/hrnet_w32.yaml
@@ -0,0 +1,81 @@
+data:
+ colormode: RGB
+ inference:
+ auto_padding:
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ scaling: [1.0, 1.0]
+ rotation: 30
+ translation: 0
+ gaussian_noise: 12.75
+ normalize_images: true
+ auto_padding:
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+device: auto
+method: td
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w32
+ pretrained: false
+ freeze_bn_stats: True
+ freeze_bn_weights: False
+ interpolate_branches: false
+ increased_channel_count: false
+ backbone_output_channels: 32
+ heads:
+ bodypart:
+ type: HeatmapHead
+ weight_init: "normal"
+ predictor:
+ type: HeatmapPredictor
+ apply_sigmoid: false
+ clip_scores: true
+ location_refinement: false
+ locref_std: 7.2801
+ target_generator:
+ type: HeatmapGaussianGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ generate_locref: false
+ locref_std: 7.2801
+ criterion:
+ heatmap:
+ type: WeightedMSECriterion
+ weight: 1.0
+ heatmap_config:
+ channels: [32, "num_bodyparts"]
+ kernel_size: [1]
+ strides: [1]
+net_type: hrnet_w32
+runner:
+ type: PoseTrainingRunner
+ key_metric: "test.mAP"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-5
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-6 ], [ 1e-7 ] ]
+ milestones: [ 160, 190 ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 1
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 200
+ seed: 42
diff --git a/deeplabcut/modelzoo/model_configs/resnet_50.yaml b/deeplabcut/modelzoo/model_configs/resnet_50.yaml
new file mode 100644
index 0000000000..994840c7d6
--- /dev/null
+++ b/deeplabcut/modelzoo/model_configs/resnet_50.yaml
@@ -0,0 +1,88 @@
+data:
+ colormode: RGB
+ inference:
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ scaling: [1.0, 1.0]
+ rotation: 30
+ translation: 0
+ gaussian_noise: 12.75
+ normalize_images: true
+device: auto
+method: td
+model:
+ backbone:
+ type: ResNet
+ model_name: resnet50_gn
+ output_stride: 16
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ backbone_output_channels: 2048
+ heads:
+ bodypart:
+ type: HeatmapHead
+ weight_init: normal
+ predictor:
+ type: HeatmapPredictor
+ apply_sigmoid: false
+ clip_scores: true
+ location_refinement: true
+ locref_std: 7.2801
+ target_generator:
+ type: HeatmapGaussianGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ generate_locref: true
+ locref_std: 7.2801
+ criterion:
+ heatmap:
+ type: WeightedMSECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+ heatmap_config:
+ channels:
+ - 2048
+ - "num_bodyparts"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+ locref_config:
+ channels:
+ - 2048
+ - "num_bodyparts x 2"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+net_type: resnet_50
+runner:
+ type: PoseTrainingRunner
+ key_metric: "test.mAP"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-5
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-6 ], [ 1e-7 ] ]
+ milestones: [ 160, 190 ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 1
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 100
+ seed: 42
diff --git a/deeplabcut/modelzoo/model_configs/ssdlite.yaml b/deeplabcut/modelzoo/model_configs/ssdlite.yaml
new file mode 100644
index 0000000000..04e694fa0a
--- /dev/null
+++ b/deeplabcut/modelzoo/model_configs/ssdlite.yaml
@@ -0,0 +1,50 @@
+data:
+ colormode: RGB
+ inference:
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [ 1.0, 1.0 ]
+ translation: 40
+ collate:
+ type: ResizeFromDataSizeCollate
+ min_scale: 0.4
+ max_scale: 1.0
+ min_short_side: 128
+ max_short_side: 1152
+ multiple_of: 32
+ to_square: false
+ hflip: true
+ normalize_images: true
+device: auto
+model:
+ type: SSDLite
+ box_score_thresh: 0.6
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+runner:
+ type: DetectorTrainingRunner
+ key_metric: "test.mAP@50:95"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-5
+ scheduler:
+ type: LRListScheduler
+ params:
+ milestones: [ 90 ]
+ lr_list: [ [ 1e-6 ] ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 8
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 250
\ No newline at end of file
diff --git a/deeplabcut/modelzoo/models.json b/deeplabcut/modelzoo/models.json
deleted file mode 100644
index f5e6ab25f9..0000000000
--- a/deeplabcut/modelzoo/models.json
+++ /dev/null
@@ -1,4 +0,0 @@
-{
- "superanimal_quadruped": "superquadruped.yaml",
- "superanimal_topviewmouse": "supertopview.yaml"
-}
diff --git a/deeplabcut/modelzoo/models_to_framework.json b/deeplabcut/modelzoo/models_to_framework.json
new file mode 100644
index 0000000000..6ff6f969f4
--- /dev/null
+++ b/deeplabcut/modelzoo/models_to_framework.json
@@ -0,0 +1,5 @@
+{
+ "dlcrnet": "tensorflow",
+ "hrnet_w32": "pytorch",
+ "resnet_50": "pytorch"
+}
diff --git a/deeplabcut/modelzoo/project_configs/superanimal_bird.yaml b/deeplabcut/modelzoo/project_configs/superanimal_bird.yaml
new file mode 100644
index 0000000000..3144433a14
--- /dev/null
+++ b/deeplabcut/modelzoo/project_configs/superanimal_bird.yaml
@@ -0,0 +1,105 @@
+# Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+
+
+# Project path (change when moving around)
+project_path:
+
+
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+
+
+# Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+- back
+- bill
+- belly
+- breast
+- crown
+- forehead
+- left_eye
+- left_leg
+- left_wing_tip
+- left_wrist
+- nape
+- right_eye
+- right_leg
+- right_wing_tip
+- right_wrist
+- tail_tip
+- throat
+- neck
+- tail_left
+- tail_right
+- upper_spine
+- upper_half_spine
+- lower_half_spine
+- right_foot
+- left_foot
+- left_half_chest
+- right_half_chest
+- chin
+- left_tibia
+- right_tibia
+- lower_spine
+- upper_half_neck
+- lower_half_neck
+- left_chest
+- right_chest
+- upper_neck
+- left_wing_shoulder
+- left_wing_elbow
+- right_wing_shoulder
+- right_wing_elbow
+- upper_cere
+- lower_cere
+
+
+# Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+
+
+# Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff:
+dotsize:
+alphavalue:
+colormap: rainbow
+
+
+# Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+detector_snapshotindex: -1
+batch_size: 1
+detector_batch_size: 1
+
+
+# Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+#if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+
+
+# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+
+
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
diff --git a/deeplabcut/modelzoo/project_configs/superanimal_quadruped.yaml b/deeplabcut/modelzoo/project_configs/superanimal_quadruped.yaml
new file mode 100644
index 0000000000..e06f82032f
--- /dev/null
+++ b/deeplabcut/modelzoo/project_configs/superanimal_quadruped.yaml
@@ -0,0 +1,102 @@
+# Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+
+
+# Project path (change when moving around)
+project_path:
+
+
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+
+
+# Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+- nose
+- upper_jaw
+- lower_jaw
+- mouth_end_right
+- mouth_end_left
+- right_eye
+- right_earbase
+- right_earend
+- right_antler_base
+- right_antler_end
+- left_eye
+- left_earbase
+- left_earend
+- left_antler_base
+- left_antler_end
+- neck_base
+- neck_end
+- throat_base
+- throat_end
+- back_base
+- back_end
+- back_middle
+- tail_base
+- tail_end
+- front_left_thai
+- front_left_knee
+- front_left_paw
+- front_right_thai
+- front_right_knee
+- front_right_paw
+- back_left_paw
+- back_left_thai
+- back_right_thai
+- back_left_knee
+- back_right_knee
+- back_right_paw
+- belly_bottom
+- body_middle_right
+- body_middle_left
+
+
+# Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+
+
+# Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff:
+dotsize:
+alphavalue:
+colormap: rainbow
+
+
+# Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+detector_snapshotindex: -1
+batch_size: 1
+detector_batch_size: 1
+
+
+# Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+#if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+
+
+# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+
+
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
diff --git a/deeplabcut/modelzoo/project_configs/superanimal_topviewmouse.yaml b/deeplabcut/modelzoo/project_configs/superanimal_topviewmouse.yaml
new file mode 100644
index 0000000000..bd09628322
--- /dev/null
+++ b/deeplabcut/modelzoo/project_configs/superanimal_topviewmouse.yaml
@@ -0,0 +1,90 @@
+# Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+
+
+# Project path (change when moving around)
+project_path:
+
+
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+
+
+# Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+- nose
+- left_ear
+- right_ear
+- left_ear_tip
+- right_ear_tip
+- left_eye
+- right_eye
+- neck
+- mid_back
+- mouse_center
+- mid_backend
+- mid_backend2
+- mid_backend3
+- tail_base
+- tail1
+- tail2
+- tail3
+- tail4
+- tail5
+- left_shoulder
+- left_midside
+- left_hip
+- right_shoulder
+- right_midside
+- right_hip
+- tail_end
+- head_midpoint
+
+
+# Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+
+
+# Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff:
+dotsize:
+alphavalue:
+colormap: rainbow
+
+
+# Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+detector_snapshotindex: -1
+batch_size: 1
+detector_batch_size: 1
+
+
+# Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+#if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+
+
+# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+
+
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
diff --git a/deeplabcut/modelzoo/utils.py b/deeplabcut/modelzoo/utils.py
index adc78170d7..20a45a70aa 100644
--- a/deeplabcut/modelzoo/utils.py
+++ b/deeplabcut/modelzoo/utils.py
@@ -4,18 +4,375 @@
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
-import json
+from __future__ import annotations
+
import os
+import warnings
+from glob import glob
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+from matplotlib.colors import ListedColormap
+
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.core.conversion_table import ConversionTable
+from deeplabcut.utils.auxiliaryfunctions import (
+ get_bodyparts,
+ get_deeplabcut_path,
+ read_config,
+ write_config,
+)
+
+
+def dlc_modelzoo_path() -> Path:
+ """Returns: the path to the `modelzoo` folder in the DeepLabCut installation"""
+ dlc_root_path = Path(get_deeplabcut_path())
+ return dlc_root_path / "modelzoo"
+
+
+def get_super_animal_project_cfg(super_animal: str) -> dict:
+ """Gets the project configuration file for a SuperAnimal model
+
+ Args:
+ super_animal: the name of the SuperAnimal model for which to load the project
+ configuration
+
+ Returns:
+ the project configuration for the given SuperAnimal model
+
+ Raises:
+ ValueError if no such SuperAnimal is found
+ """
+ project_configs_dir = dlc_modelzoo_path() / "project_configs"
+ super_animal_projects = {p.stem: p for p in project_configs_dir.iterdir()}
+ if super_animal not in super_animal_projects:
+ raise ValueError(
+ f"No such SuperAnimal model: {super_animal}. Available SuperAnimal models "
+ f"are {', '.join(super_animal_projects.keys())}."
+ )
+
+ return read_config_as_dict(super_animal_projects[super_animal])
+
+
+def get_super_animal_scorer(
+ super_animal: str,
+ model_snapshot_path: Path,
+ detector_snapshot_path: Path | None,
+) -> str:
+ """
+ Args:
+ super_animal: The SuperAnimal dataset on which the models were trained
+ model_snapshot_path: The path for the SuperAnimal pose model snapshot
+ detector_snapshot_path: The path for the SuperAnimal detector snapshot, if a
+ detector is being used.
+
+ Returns:
+ The DLC scorer name to use for the given SuperAnimal models.
+ """
+ super_animal_prefix = super_animal + "_"
+ dlc_scorer = super_animal_prefix
+
+ if detector_snapshot_path is not None:
+ detector_name = detector_snapshot_path.stem
+ if detector_name.startswith(super_animal_prefix):
+ detector_name = detector_name[len(super_animal_prefix) :]
+ dlc_scorer += f"{detector_name}_"
+
+ model_name = model_snapshot_path.stem
+ if model_name.startswith(super_animal_prefix):
+ model_name = model_name[len(super_animal_prefix) :]
+ dlc_scorer += f"{model_name}"
+
+ return dlc_scorer
+
+
+def create_conversion_table(
+ config: str | Path,
+ super_animal: str,
+ project_to_super_animal: dict[str, str],
+) -> ConversionTable:
+ """
+ Creates a conversion table mapping bodyparts defined for a DeepLabCut project
+ to bodyparts defined for a SuperAnimal model. This allows to fine-tune SuperAnimal
+ weights instead of transfer learning from ImageNet. The conversion table is directly
+ added to the project's configuration file.
+
+ Args:
+ config: The path to the project configuration for which the conversion table
+ should be created.
+ super_animal: The SuperAnimal model for the conversion table
+ project_to_super_animal: The conversion table mapping each project bodypart
+ to the corresponding SuperAnimal bodypart.
+
+ Returns:
+ The conversion table that was added to the project config.
+
+ Raises:
+ ValueError: If the conversion table is misconfigured (e.g., if there are
+ misnamed bodyparts in the table). See ConversionTable for more.
+ """
+ cfg = read_config(str(config))
+ sa_cfg = get_super_animal_project_cfg(super_animal)
+ conversion_table = ConversionTable(
+ super_animal=super_animal,
+ project_bodyparts=get_bodyparts(cfg),
+ super_animal_bodyparts=sa_cfg["bodyparts"],
+ table=project_to_super_animal,
+ )
+
+ conversion_tables = cfg.get("SuperAnimalConversionTables")
+ if conversion_tables is None:
+ conversion_tables = {}
+
+ conversion_tables[super_animal] = conversion_table.table
+ cfg["SuperAnimalConversionTables"] = conversion_tables
+ write_config(str(config), cfg)
+ return conversion_table
+
+
+def get_conversion_table(cfg: dict | str | Path, super_animal: str) -> ConversionTable:
+ """Gets the conversion table from a project to a SuperAnimal model
+
+ Args:
+ cfg: The path to a project configuration file, or directly the project config.
+ super_animal: The SuperAnimal for which to get the configuration file.
+
+ Returns:
+ A dictionary mapping {project_bodypart: super_animal_bodypart}
+
+ Raises:
+ ValueError: If the conversion table is misconfigured (e.g., if there are
+ misnamed bodyparts in the table). See ConversionTable for more.
+ """
+ if isinstance(cfg, (str, Path)):
+ cfg = read_config(str(cfg))
+
+ conversion_tables = cfg.get("SuperAnimalConversionTables", {})
+ if conversion_tables is None or super_animal not in conversion_tables:
+ raise ValueError(
+ f"No conversion table defined in the project config for {super_animal}."
+ "Call deeplabcut.modelzoo.create_conversion_table to create one."
+ )
+
+ sa_cfg = get_super_animal_project_cfg(super_animal)
+ conversion_table = ConversionTable(
+ super_animal=super_animal,
+ project_bodyparts=get_bodyparts(cfg),
+ super_animal_bodyparts=sa_cfg["bodyparts"],
+ table=conversion_tables[super_animal],
+ )
+ return conversion_table
+
+
+def read_conversion_table_from_csv(csv_path):
+ df = pd.read_csv(csv_path, skiprows=1, header=None)
+ df = df.dropna()
+ df[0] = df[0].str.replace(r"\s+", "", regex=True)
+ df[1] = df[1].str.replace(r"\s+", "", regex=True)
+ _map = dict(zip(df[0], df[1]))
+ return _map
+
+
+def parse_project_model_name(superanimal_name: str) -> tuple[str, str]:
+ """Parses model zoo model names for SuperAnimal models
+
+ Args:
+ superanimal_name: the name of the SuperAnimal model name to parse
+
+ Returns:
+ project_name: the parsed SuperAnimal model name
+ model_name: the model architecture (e.g., dlcrnet, hrnetw32)
+ """
+
+ if superanimal_name == "superanimal_quadruped":
+ warnings.warn(
+ f"{superanimal_name} is deprecated and will be removed in a future version. Use {superanimal_name}_model_suffix instead.",
+ DeprecationWarning,
+ )
+ superanimal_name = "superanimal_quadruped_hrnetw32"
+
+ if superanimal_name == "superanimal_topviewmouse":
+ warnings.warn(
+ f"{superanimal_name} is deprecated and will be removed in a future version. Use {superanimal_name}_model_suffix instead.",
+ DeprecationWarning,
+ )
+ superanimal_name = "superanimal_topviewmouse_dlcrnet"
+
+ model_name = superanimal_name.split("_")[-1]
+ project_name = superanimal_name.replace(f"_{model_name}", "")
+
+ dlc_root_path = get_deeplabcut_path()
+ modelzoo_path = os.path.join(dlc_root_path, "modelzoo")
+
+ available_model_configs = glob(
+ os.path.join(modelzoo_path, "model_configs", "*.yaml")
+ )
+ available_models = [
+ os.path.splitext(os.path.basename(path))[0] for path in available_model_configs
+ ]
+
+ if model_name not in available_models:
+ raise ValueError(
+ f"Model {model_name} not found. Available models are: {available_models}"
+ )
+
+ available_project_configs = glob(
+ os.path.join(modelzoo_path, "project_configs", "*.yaml")
+ )
+ available_projects = [
+ os.path.splitext(os.path.basename(path))[0]
+ for path in available_project_configs
+ ]
+
+ return project_name, model_name
-def parse_available_supermodels():
- import deeplabcut
+def get_superanimal_colormaps():
+ # FIXME(shaokai) - Add colormaps for the SuperBird dataset
+ superanimal_bird_colors = (
+ np.array(
+ [
+ (127, 0, 255),
+ (115, 18, 254),
+ (103, 37, 254),
+ (91, 56, 253),
+ (79, 74, 252),
+ (65, 95, 250),
+ (53, 112, 248),
+ (41, 128, 246),
+ (29, 144, 243),
+ (15, 162, 239),
+ (3, 176, 236),
+ (8, 189, 232),
+ (20, 201, 228),
+ (34, 214, 223),
+ (46, 223, 219),
+ (58, 232, 214),
+ (70, 239, 209),
+ (84, 246, 202),
+ (96, 250, 196),
+ (108, 253, 190),
+ (120, 254, 184),
+ (134, 254, 176),
+ (146, 253, 169),
+ (158, 250, 162),
+ (170, 246, 154),
+ (184, 239, 146),
+ (196, 232, 138),
+ (208, 223, 130),
+ (220, 214, 122),
+ (234, 201, 112),
+ (246, 189, 103),
+ (255, 176, 95),
+ (255, 162, 86),
+ (255, 144, 75),
+ (255, 128, 66),
+ (255, 112, 57),
+ (255, 95, 48),
+ (255, 74, 37),
+ (255, 56, 28),
+ (255, 37, 18),
+ (255, 18, 9),
+ (255, 0, 0),
+ ]
+ )
+ / 255
+ )
+ superanimal_topviewmouse_colors = (
+ np.array(
+ [
+ [127, 0, 255],
+ [109, 28, 254],
+ [91, 56, 253],
+ [71, 86, 251],
+ [53, 112, 248],
+ [33, 139, 244],
+ [15, 162, 239],
+ [4, 185, 234],
+ [22, 203, 228],
+ [42, 220, 220],
+ [60, 233, 213],
+ [80, 244, 204],
+ [98, 250, 195],
+ [118, 254, 185],
+ [136, 254, 175],
+ [156, 250, 163],
+ [174, 244, 152],
+ [194, 233, 139],
+ [212, 220, 127],
+ [232, 203, 113],
+ [250, 185, 100],
+ [255, 162, 86],
+ [255, 139, 72],
+ [255, 112, 57],
+ [255, 86, 43],
+ [255, 56, 28],
+ [255, 28, 14],
+ ]
+ )
+ / 255
+ )
+ superanimal_quadruped_colors = (
+ np.array(
+ [
+ [255.0, 0.0, 0.0],
+ [255.0, 39.63408568671726, 0.0],
+ [255.0, 79.26817137343453, 0.0],
+ [255.0, 118.9022570601518, 0.0],
+ [255.0, 158.53634274686905, 0.0],
+ [255.0, 198.17042843358632, 0.0],
+ [255.0, 237.8045141203036, 0.0],
+ [232.56140019297916, 255.0, 0.0],
+ [192.92731450626187, 255.0, 0.0],
+ [153.2932288195446, 255.0, 0.0],
+ [113.65914313282731, 255.0, 0.0],
+ [74.02505744611004, 255.0, 0.0],
+ [34.390971759392784, 255.0, 0.0],
+ [3.5647953575585385, 255.0, 8.807909284882923],
+ [0.0, 255.0, 44.87701729490043],
+ [0.0, 255.0, 84.51085328820125],
+ [0.0, 255.0, 124.14468928150207],
+ [0.0, 255.0, 163.77852527480275],
+ [0.0, 255.0, 203.4123612681037],
+ [0.0, 255.0, 243.04619726140453],
+ [0, 220, 255],
+ [0, 255, 255],
+ [0, 165, 255],
+ [0, 150, 255],
+ [0.0, 68.78344961404169, 255.0],
+ [0.0, 29.14936392732455, 255.0],
+ [10.484721759392611, 0.0, 255.0],
+ [50.11880744611004, 0.0, 255.0],
+ [89.75289313282732, 0.0, 255.0],
+ [129.38697881954448, 0.0, 255.0],
+ [169.02106450626192, 0.0, 255.0],
+ [169.02106450626192, 0.0, 255.0],
+ [255.0, 0.0, 142.80850706015173],
+ [169.02106450626192, 0.0, 255.0],
+ [255.0, 0.0, 142.80850706015173],
+ [255.0, 0.0, 142.80850706015173],
+ [255.0, 0.0, 103.17442137343447],
+ [255.0, 0.0, 63.54033568671722],
+ [255.0, 0.0, 23.90625],
+ ]
+ )
+ / 255
+ )
- dlc_path = deeplabcut.utils.auxiliaryfunctions.get_deeplabcut_path()
- json_path = os.path.join(dlc_path, "modelzoo", "models.json")
- with open(json_path) as file:
- return json.load(file)
+ superanimal_colormaps = {
+ "superanimal_bird": ListedColormap(
+ list(superanimal_bird_colors), name="superanimal_bird"
+ ),
+ "superanimal_topviewmouse": ListedColormap(
+ list(superanimal_topviewmouse_colors), name="superanimal_topviewmouse"
+ ),
+ "superanimal_quadruped": ListedColormap(
+ list(superanimal_quadruped_colors), name="superanimal_quadruped"
+ ),
+ }
+ return superanimal_colormaps
diff --git a/deeplabcut/modelzoo/video_inference.py b/deeplabcut/modelzoo/video_inference.py
new file mode 100644
index 0000000000..f28d4846ca
--- /dev/null
+++ b/deeplabcut/modelzoo/video_inference.py
@@ -0,0 +1,507 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import json
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+from dlclibrary.dlcmodelzoo.modelzoo_download import download_huggingface_model
+from ruamel.yaml import YAML
+
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.modelzoo.utils import get_super_animal_scorer
+from deeplabcut.pose_estimation_pytorch.modelzoo.train_from_coco import adaptation_train
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ get_snapshot_folder_path,
+ get_super_animal_snapshot_path,
+ load_super_animal_config,
+ update_config,
+)
+from deeplabcut.utils.auxiliaryfunctions import get_deeplabcut_path
+from deeplabcut.utils.pseudo_label import (
+ dlc3predictions_2_annotation_from_video,
+ video_to_frames,
+)
+
+
+def video_inference_superanimal(
+ videos: Union[str, list],
+ superanimal_name: str,
+ model_name: str,
+ detector_name: str | None = None,
+ scale_list: Optional[list] = None,
+ videotype: str = ".mp4",
+ dest_folder: Optional[str] = None,
+ cropping: list[int] | None = None,
+ video_adapt: bool = False,
+ plot_trajectories: bool = False,
+ batch_size: int = 1,
+ detector_batch_size: int = 1,
+ pcutoff: float = 0.1,
+ adapt_iterations: int = 1000,
+ pseudo_threshold: float = 0.1,
+ bbox_threshold: float = 0.9,
+ detector_epochs: int = 4,
+ pose_epochs: int = 4,
+ max_individuals: int = 10,
+ video_adapt_batch_size: int = 8,
+ device: Optional[str] = "auto",
+ customized_pose_checkpoint: Optional[str] = None,
+ customized_detector_checkpoint: Optional[str] = None,
+ customized_model_config: Optional[str] = None,
+ plot_bboxes: bool = True,
+):
+ """
+ This function performs inference on videos using a pretrained SuperAnimal model.
+
+ IMPORTANT: Note that since we have both TensorFlow and PyTorch Engines, we will
+ route the engine based on the model you select:
+
+ * dlcrnet -> TensorFlow
+ * all others - > PyTorch
+
+ Parameters
+ ----------
+
+ videos (str or list):
+ The path to the video or a list of paths to videos.
+
+ superanimal_name (str):
+ The name of the SuperAnimal dataset for which to load a pre-trained model.
+
+ model_name (str):
+ The model architecture to use for inference.
+
+ detector_name (str):
+ For top-down models (only available with the PyTorch framework), the type of
+ object detector to use for inference.
+
+ scale_list (list):
+ A list of different resolutions for the spatial pyramid. Used only for bottom up models.
+
+ videotype (str):
+ Checks for the extension of the video in case the input to the video is a directory.
+ Only videos with this extension are analyzed. The default is ``.mp4``.
+
+ dest_folder (str): The path to the folder where the results should be saved.
+
+ cropping: list or None, optional, default=None
+ Only for SuperAnimal models running with the PyTorch engine.
+ List of cropping coordinates as [x1, x2, y1, y2].
+ Note that the same cropping parameters will then be used for all videos.
+ If different video crops are desired, run ``video_inference_superanimal`` on
+ individual videos with the corresponding cropping coordinates.
+
+ video_adapt (bool):
+ Whether to perform video adaptation. The default is False.
+ You only need to perform it on one video because the adaptation generalizes to all videos that are similar.
+
+ plot_trajectories (bool):
+ Whether to plot the trajectories. The default is False.
+
+ batch_size (int):
+ The batch size to use for video inference. Only for PyTorch models.
+
+ detector_batch_size (int):
+ The batch size to use for the detector during video inference. Only for PyTorch.
+
+ pcutoff (float):
+ The p-value cutoff for the confidence of the prediction. The default is 0.1.
+
+ adapt_iterations (int):
+ Number of iterations for adaptation training. Empirically 1000 is sufficient.
+
+ bbox_threshold (float):
+ The pseudo-label threshold for the confidence of the detector. The default is 0.9
+
+ detector_epochs (int):
+ Used in the PyTorch engine. The number of epochs for training the detector. The default is 4.
+
+ pose_epochs (int):
+ Used in the PyTorch engine. The number of epochs for training the pose estimator. The default is 4.
+
+ pseudo_threshold (float):
+ The pseudo-label threshold for the confidence of the prediction. The default is 0.1.
+
+ max_individuals (int):
+ The maximum number of individuals in the video. The default is 30. Used only for top down models.
+
+ video_adapt_batch_size (int):
+ The batch size to use for video adaptation.
+
+ device (str):
+ The device to use for inference. The default is None (CPU). Used only for PyTorch models.
+
+ customized_pose_checkpoint (str):
+ Used in the PyTorch engine. If specified, it replaces the default pose checkpoint.
+
+ customized_detector_checkpoint (str):
+ Used in the PyTorch engine. If specified, it replaces the default detector checkpoint.
+
+ customized_model_config (str):
+ Used for loading customized model config. Only supported in Pytorch
+
+ plot_bboxes (bool):
+ If using Top-Down approach, whether to plot the detector's bounding boxes. The default is True.
+
+ Raises:
+ NotImplementedError:
+ If the model is not found in the modelzoo.
+ Warning: If the superanimal_name will be deprecated in the future.
+
+ (Model Explanation) SuperAnimal-Quadruped:
+ `superanimal_quadruped` models aim to work across a large range of quadruped
+ animals, from horses, dogs, sheep, rodents, to elephants. The camera perspective is
+ orthogonal to the animal ("side view"), and most of the data includes the animals
+ face (thus the front and side of the animal). You will note we have several variants
+ that differ in speed vs. performance, so please do test them out on your data to see
+ which is best suited for your application. Also note we have a "video adaptation"
+ feature, which lets you adapt your data to the model in a self-supervised way.
+ No labeling needed!
+
+ All model snapshots are automatically downloaded to modelzoo/checkpoints when used.
+
+ - PLEASE SEE THE FULL DATASHEET: https://zenodo.org/records/10619173
+ - MORE DETAILS ON THE MODELS (detector, pose estimators):
+ https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped
+ - We provide several models:
+ - `hrnet_w32` (Top-Down pose estimation model, PyTorch engine)
+ An `hrnet_w32` is a top-down model that is paired with a detector. That
+ means it takes a cropped image from an object detector and predicts the
+ keypoints. When selecting this variant, a `detector_name` must be set with
+ one of the provided object detectors.
+ - `dlcrnet` (TensorFlow engine)
+ This is a bottom-up model that predicts all keypoints then groups them into
+ individuals. This can be faster, but more error prone.
+ - We provide one object detector (only for the PyTorch engine):
+ - `fasterrcnn_resnet50_fpn_v2`
+ This is a FasterRCNN model with a ResNet backbone, see
+ https://pytorch.org/vision/stable/models/faster_rcnn.html
+
+ (Model Explanation) SuperAnimal-TopViewMouse:
+ `superanimal_topviewmouse` aims to work across lab mice in different lab settings
+ from a top-view perspective; this is very polar in many behavioral assays in freely
+ moving mice.
+
+ All model snapshots are automatically downloaded to modelzoo/checkpoints when used.
+
+ - [PLEASE SEE THE FULL DATASHEET HERE](https://zenodo.org/records/10618947)
+ - [MORE DETAILS ON THE MODELS (detector, pose estimators)](https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse)
+ - We provide several models:
+ - `hrnet_w32` (Top-Down pose estimation model, PyTorch engine)
+ An `hrnet_w32` is a top-down model that is paired with a detector. That
+ means it takes a cropped image from an object detector and predicts the
+ keypoints. When selecting this variant, a `detector_name` must be set with
+ one of the provided object detectors.
+ - `dlcrnet` (TensorFlow engine)
+ This is a bottom-up model that predicts all keypoints then groups them into
+ individuals. This can be faster, but more error prone.
+ - We provide one object detector (only for the PyTorch engine):
+ - `fasterrcnn_resnet50_fpn_v2`
+ This is a FasterRCNN model with a ResNet backbone, see
+ https://pytorch.org/vision/stable/models/faster_rcnn.html
+
+ (Model Explanation) SuperAnimal-Bird:
+ TODO(shaokai)
+
+ Examples (PyTorch Engine)
+ --------
+ >>> import deeplabcut.modelzoo.video_inference.video_inference_superanimal as video_inference_superanimal
+ >>> video_inference_superanimal(
+ videos=["/mnt/md0/shaokai/DLCdev/3mice_video1_short.mp4"],
+ superanimal_name="superanimal_topviewmouse",
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ video_adapt=True,
+ max_individuals=3,
+ pseudo_threshold=0.1,
+ bbox_threshold=0.9,
+ detector_epochs=4,
+ pose_epochs=4,
+ )
+
+ Tips:
+ * max_individuals: make sure you correctly give the number of individuals. Our
+ inference api will only give up to max_individuals number of predictions.
+ * pseudo_threshold: the higher you set, the more aggressive you filter low
+ confidence predictions during video adaptation.
+ * bbox_threshold: the higher you set, the more aggressive you filter low confidence
+ bounding boxes during video adaptation. Different from our paper, we now add
+ video adaptation to the object detector as well.
+ * detector_epochs and pose_epochs do not need to be to high as video adaptation does
+ not require too much training. However, you can make them higher if you see a
+ substaintial gain in the training logs.
+
+ Examples
+ --------
+
+ >>> from deeplabcut.modelzoo.video_inference import video_inference_superanimal
+ >>> videos = ["/path/to/my/video.mp4"]
+ >>> superanimal_name = "superanimal_topviewmouse"
+ >>> videotype = "mp4"
+ >>> scale_list = [200, 300, 400]
+ >>> video_inference_superanimal(
+ videos,
+ superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ scale_list = scale_list,
+ videotype = videotype,
+ video_adapt = True,
+ )
+
+ Tips:
+ scale_list: it's recommended to leave this as empty list []. Empirically
+ [200, 300, 400] works well. We needed to do this as bottom-up models in TensorFlow
+ are sensitive to the scales of the image.
+ If you find your predictions not good without scale_list or it's too hard to find
+ the right scale_list, you can try to use the PyTorch engine.
+ """
+ if scale_list is None:
+ scale_list = []
+
+ print(f"Running video inference on {videos} with {superanimal_name}_{model_name}")
+ dlc_root_path = get_deeplabcut_path()
+ modelzoo_path = os.path.join(dlc_root_path, "modelzoo")
+ available_architectures = json.load(
+ open(os.path.join(modelzoo_path, "models_to_framework.json"), "r")
+ )
+ framework = available_architectures[model_name]
+ print(f"Using {framework} for model {model_name}")
+ if framework == "tensorflow":
+ from deeplabcut.pose_estimation_tensorflow.modelzoo.api.superanimal_inference import (
+ _video_inference_superanimal,
+ )
+
+ weight_folder = get_snapshot_folder_path() / f"{superanimal_name}_{model_name}"
+ if not weight_folder.exists():
+ download_huggingface_model(
+ superanimal_name, target_dir=str(weight_folder), rename_mapping=None
+ )
+
+ if isinstance(videos, str):
+ videos = [videos]
+ _video_inference_superanimal(
+ videos,
+ superanimal_name,
+ model_name,
+ scale_list,
+ videotype,
+ video_adapt,
+ plot_trajectories,
+ pcutoff,
+ adapt_iterations,
+ pseudo_threshold,
+ )
+ elif framework == "pytorch":
+ if detector_name is None:
+ raise ValueError(
+ "You have to specify a detector_name when using the Pytorch framework."
+ )
+
+ from deeplabcut.pose_estimation_pytorch.modelzoo.inference import (
+ _video_inference_superanimal,
+ )
+
+ if customized_model_config is not None:
+ config = read_config_as_dict(customized_model_config)
+ else:
+ config = load_super_animal_config(
+ super_animal=superanimal_name,
+ model_name=model_name,
+ detector_name=detector_name,
+ )
+
+ pose_model_path = customized_pose_checkpoint
+ if pose_model_path is None:
+ pose_model_path = get_super_animal_snapshot_path(
+ dataset=superanimal_name,
+ model_name=model_name,
+ )
+
+ detector_path = customized_detector_checkpoint
+ if detector_path is None:
+ detector_path = get_super_animal_snapshot_path(
+ dataset=superanimal_name,
+ model_name=detector_name,
+ )
+
+ dlc_scorer = get_super_animal_scorer(
+ superanimal_name, pose_model_path, detector_path
+ )
+
+ config = update_config(config, max_individuals, device)
+ output_suffix = "_before_adapt"
+ if video_adapt:
+ # the users can pass in many videos. For now, we only use one video for
+ # video adaptation. As reported in Ye et al. 2024, one video should be
+ # sufficient for video adaptation.
+ video_path = Path(videos[0])
+ print(f"Using {video_path} for video adaptation training")
+
+ # video inference to get pseudo label
+ _video_inference_superanimal(
+ [str(video_path)],
+ superanimal_name,
+ model_cfg=config,
+ model_snapshot_path=pose_model_path,
+ detector_snapshot_path=detector_path,
+ max_individuals=max_individuals,
+ pcutoff=pcutoff,
+ batch_size=batch_size,
+ detector_batch_size=detector_batch_size,
+ cropping=cropping,
+ dest_folder=dest_folder,
+ output_suffix=output_suffix,
+ plot_bboxes=plot_bboxes,
+ bboxes_pcutoff=bbox_threshold,
+ )
+
+ # we prepare the pseudo dataset in the same folder of the target video
+ pseudo_dataset_folder = video_path.with_name(f"pseudo_{video_path.stem}")
+ pseudo_dataset_folder.mkdir(exist_ok=True)
+ model_folder = pseudo_dataset_folder / "checkpoints"
+ model_folder.mkdir(exist_ok=True)
+
+ image_folder = pseudo_dataset_folder / "images"
+ if image_folder.exists():
+ print(f"{image_folder} exists, skipping the frame extraction")
+ else:
+ image_folder.mkdir()
+ print(
+ f"Video frames being extracted to {image_folder} for video "
+ f"adaptation."
+ )
+ video_to_frames(video_path, pseudo_dataset_folder, cropping=cropping)
+
+ anno_folder = pseudo_dataset_folder / "annotations"
+ if (anno_folder / "train.json").exists() and (
+ anno_folder / "test.json"
+ ).exists():
+ print(
+ f"{anno_folder} exists, skipping the annotation construction. "
+ f"Delete the folder if you want to re-construct pseudo annotations"
+ )
+ else:
+ anno_folder.mkdir()
+
+ if dest_folder is None:
+ pseudo_anno_dir = video_path.parent
+ else:
+ pseudo_anno_dir = Path(dest_folder)
+
+ pseudo_anno_name = f"{video_path.stem}_{dlc_scorer}_before_adapt.json"
+ with open(pseudo_anno_dir / pseudo_anno_name, "r") as f:
+ predictions = json.load(f)
+
+ # make sure we tune parameters inside this function such as pseudo
+ # threshold etc.
+ print(f"Constructing pseudo dataset at {pseudo_dataset_folder}")
+ dlc3predictions_2_annotation_from_video(
+ predictions,
+ pseudo_dataset_folder,
+ config["metadata"]["bodyparts"],
+ superanimal_name,
+ pose_threshold=pseudo_threshold,
+ bbox_threshold=bbox_threshold,
+ )
+
+ model_snapshot_prefix = f"snapshot-{model_name}"
+ detector_snapshot_prefix = f"snapshot-{detector_name}"
+
+ config["runner"]["snapshot_prefix"] = model_snapshot_prefix
+ config["detector"]["runner"]["snapshot_prefix"] = detector_snapshot_prefix
+
+ # the model config's parameters need to be updated for adaptation training
+ model_config_path = model_folder / "pytorch_config.yaml"
+ with open(model_config_path, "w") as f:
+ yaml = YAML()
+ yaml.dump(config, f)
+
+ adapted_detector_checkpoint = (
+ model_folder / f"{detector_snapshot_prefix}-{detector_epochs:03}.pt"
+ )
+ adapted_pose_checkpoint = (
+ model_folder / f"{model_snapshot_prefix}-{pose_epochs:03}.pt"
+ )
+
+ if (
+ adapted_detector_checkpoint.exists()
+ and adapted_pose_checkpoint.exists()
+ ):
+ print(
+ f"Video adaptation already ran; pose ({adapted_pose_checkpoint}) "
+ f"and detector ({adapted_detector_checkpoint}) already exist. To "
+ "rerun video adaptation training, delete the checkpoints or select"
+ "a different number of adaptation epochs. Continuing with the"
+ "existing checkpoints."
+ )
+ else:
+ print(
+ "Running video adaptation with following parameters:\n"
+ f" (pose training) pose_epochs: {pose_epochs}\n"
+ " (pose) save_epochs: 1\n"
+ f" detector_epochs: {detector_epochs}\n"
+ " detector_save_epochs: 1\n"
+ f" video adaptation batch size: {video_adapt_batch_size}\n"
+ )
+ train_file = pseudo_dataset_folder / "annotations" / "train.json"
+ with open(train_file, "r") as f:
+ temp_obj = json.load(f)
+
+ annotations = temp_obj["annotations"]
+ if len(annotations) == 0:
+ print(
+ f"No valid predictions from {str(video_path)}. Check the "
+ "quality of the video"
+ )
+ return
+
+ adaptation_train(
+ project_root=pseudo_dataset_folder,
+ model_folder=model_folder,
+ train_file="train.json",
+ test_file="test.json",
+ model_config_path=model_config_path,
+ device=device,
+ epochs=pose_epochs,
+ save_epochs=1,
+ detector_epochs=detector_epochs,
+ detector_save_epochs=1,
+ snapshot_path=pose_model_path,
+ detector_path=detector_path,
+ batch_size=video_adapt_batch_size,
+ detector_batch_size=video_adapt_batch_size,
+ )
+
+ # Set the customized checkpoint paths and
+ output_suffix = "_after_adapt"
+ detector_path = adapted_detector_checkpoint
+ pose_model_path = adapted_pose_checkpoint
+
+ return _video_inference_superanimal(
+ videos,
+ superanimal_name,
+ model_cfg=config,
+ model_snapshot_path=pose_model_path,
+ detector_snapshot_path=detector_path,
+ max_individuals=max_individuals,
+ pcutoff=pcutoff,
+ batch_size=batch_size,
+ detector_batch_size=detector_batch_size,
+ cropping=cropping,
+ dest_folder=dest_folder,
+ output_suffix=output_suffix,
+ plot_bboxes=plot_bboxes,
+ bboxes_pcutoff=bbox_threshold,
+ )
diff --git a/deeplabcut/modelzoo/webapp/__init__.py b/deeplabcut/modelzoo/webapp/__init__.py
new file mode 100644
index 0000000000..117d127147
--- /dev/null
+++ b/deeplabcut/modelzoo/webapp/__init__.py
@@ -0,0 +1,10 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
diff --git a/deeplabcut/modelzoo/webapp/inference.py b/deeplabcut/modelzoo/webapp/inference.py
new file mode 100644
index 0000000000..806e7778f1
--- /dev/null
+++ b/deeplabcut/modelzoo/webapp/inference.py
@@ -0,0 +1,119 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from typing import Dict
+
+import numpy as np
+
+import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
+from deeplabcut.pose_estimation_pytorch.apis.utils import get_inference_runners
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import update_config
+
+
+class SingletonTopDownRunners:
+ """Singleton class for topdown runners
+
+ This class is a singleton class for topdown runners. It is used to
+ ensure that only one instance of the topdown runners is created.
+
+ Attrs:
+ config: Configuration dictionary
+ pose_model_path: Path to the pose model
+ detector_model_path: Path to the detector model
+ num_bodyparts: Number of bodyparts
+ max_individuals: Maximum number of individuals
+ """
+
+ _instance = None
+
+ def __new__(cls, *args, **kwargs):
+ if not cls._instance:
+ cls._instance = super().__new__(cls)
+ return cls._instance
+
+ def __init__(
+ self,
+ config,
+ pose_model_path: str,
+ detector_model_path: str,
+ num_bodyparts: int,
+ max_individuals: int,
+ ):
+
+ pose_runner, detector_runner = get_inference_runners(
+ config,
+ snapshot_path=pose_model_path,
+ max_individuals=max_individuals,
+ num_bodyparts=num_bodyparts,
+ num_unique_bodyparts=0,
+ detector_path=detector_model_path,
+ )
+ self.pose_runner = pose_runner
+ self.detector_runner = detector_runner
+
+
+class SuperanimalPyTorchInference:
+ """Superanimal inference class
+
+ This class is used to perform inference on a superanimal model from the
+ DeepLabCut model zoo website.
+ """
+
+ def __init__(
+ self,
+ project_name: str,
+ pose_model_type: str = "hrnet_w32",
+ detector_model_type: str = "fasterrcnn_resnet50_fpn_v2",
+ max_individuals: int = 30,
+ device: str = "cpu",
+ ):
+ self.max_individuals = max_individuals
+ config = modelzoo.load_super_animal_config(
+ super_animal=project_name,
+ model_name=pose_model_type,
+ detector_name=detector_model_type,
+ )
+ config = update_config(config, max_individuals, device)
+ self._config = config
+
+ def initialize_models(self, pose_model_path: str, detector_model_path: str):
+ self.models = SingletonTopDownRunners(
+ self.config,
+ pose_model_path,
+ detector_model_path,
+ len(self.config["bodyparts"]),
+ self.max_individuals,
+ )
+
+ @property
+ def config(self):
+ return self._config
+
+ def predict(self, frames: Dict[str, np.array]):
+
+ input_images = np.array(list(frames.values()), dtype=float)
+
+ bbox_predictions = self.models.detector_runner.inference(images=input_images)
+ input_images = list(zip(input_images, bbox_predictions))
+ predictions = self.models.pose_runner.inference(images=input_images)
+ predictions = [
+ {("markers" if k == "bodyparts" else k): v for k, v in d.items()}
+ for d in predictions
+ ]
+ predictions = [
+ {**item[1], "image_path": item[0]}
+ for item in zip(frames.keys(), predictions)
+ ]
+ responses = {
+ "joint_names": self.config["bodyparts"],
+ "predictions": predictions,
+ }
+
+ return responses
diff --git a/deeplabcut/modelzoo/weight_initialization.py b/deeplabcut/modelzoo/weight_initialization.py
new file mode 100644
index 0000000000..74558a7033
--- /dev/null
+++ b/deeplabcut/modelzoo/weight_initialization.py
@@ -0,0 +1,108 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Functions to build weight initialization parameters for SuperAnimal models"""
+from pathlib import Path
+
+import deeplabcut.modelzoo.utils as utils
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ get_super_animal_snapshot_path
+)
+
+
+def build_weight_init(
+ cfg: dict | str | Path,
+ super_animal: str,
+ model_name: str,
+ detector_name: str | None,
+ with_decoder: bool = False,
+ memory_replay: bool = False,
+ customized_pose_checkpoint: str | Path | None = None,
+ customized_detector_checkpoint: str | Path | None = None,
+) -> WeightInitialization:
+ """Builds the WeightInitialization from a SuperAnimal model for a project
+
+ Args:
+ cfg: The project's configuration, or the path to the project configuration file.
+ super_animal: The SuperAnimal model with which to initialize weights.
+ model_name: The type of the model architecture for which to load the weights.
+ detector_name: The type of detector architecture for which to load the weights.
+ with_decoder: Whether to load the decoder weights as well. If this is true,
+ a conversion table must be specified for the given SuperAnimal in the
+ project configuration file. See
+ ``deeplabcut.modelzoo.utils.create_conversion_table`` to create a
+ conversion table.
+ memory_replay: Only when ``with_decoder=True``. Whether to train the model
+ with memory replay, so that it predicts all SuperAnimal bodyparts.
+ customized_pose_checkpoint: A customized SuperAnimal pose checkpoint, as an
+ alternative to the Hugging Face one
+ customized_detector_checkpoint: A customized SuperAnimal detector checkpoint, as
+ an alternative to the Hugging Face one
+
+ To build a WeightInitialization instance for a project using the conversion table
+ specified in the project configuration file, use:
+
+ ```
+ from pathlib import Path
+ from deeplabcut.utils.auxiliaryfunctions import read_config
+ from deeplabcut.modelzoo import build_weight_init
+
+ project_cfg = read_config("/path/to/my/project/config.yaml")
+ super_animal = "superanimal_quadruped"
+ weight_init = build_weight_init(
+ cfg=project_cfg,
+ super_animal="superanimal_quadruped",
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ with_decoder=True,
+ memory_replay=False,
+ )
+ ```
+
+ Returns:
+ The built WeightInitialization.
+ """
+ if isinstance(cfg, (str, Path)):
+ cfg = read_config_as_dict(cfg)
+
+ conversion_array = None
+ bodyparts = None
+ if with_decoder:
+ conversion_table = utils.get_conversion_table(cfg, super_animal)
+ conversion_array = conversion_table.to_array()
+ bodyparts = conversion_table.converted_bodyparts()
+
+ snapshot_path = customized_pose_checkpoint
+ if snapshot_path is None:
+ snapshot_path = get_super_animal_snapshot_path(
+ dataset=super_animal,
+ model_name=model_name,
+ download=True,
+ )
+
+ detector_snapshot_path = customized_detector_checkpoint
+ if detector_snapshot_path is None and detector_name is not None:
+ detector_snapshot_path = get_super_animal_snapshot_path(
+ dataset=super_animal,
+ model_name=detector_name,
+ download=True,
+ )
+
+ return WeightInitialization(
+ snapshot_path=snapshot_path,
+ detector_snapshot_path=detector_snapshot_path,
+ dataset=super_animal,
+ with_decoder=with_decoder,
+ memory_replay=memory_replay,
+ conversion_array=conversion_array,
+ bodyparts=bodyparts,
+ )
diff --git a/deeplabcut/pose_estimation_3d/triangulation.py b/deeplabcut/pose_estimation_3d/triangulation.py
index 36d04c5b9c..a0ebf98a16 100644
--- a/deeplabcut/pose_estimation_3d/triangulation.py
+++ b/deeplabcut/pose_estimation_3d/triangulation.py
@@ -4,24 +4,20 @@
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
import os
from pathlib import Path
-
import cv2
import numpy as np
import pandas as pd
-from matplotlib.axes._axes import _log as matplotlib_axes_logger
from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions
from deeplabcut.utils import auxiliaryfunctions_3d
-from deeplabcut.pose_estimation_tensorflow.lib.trackingutils import TRACK_METHODS
-
-matplotlib_axes_logger.setLevel("ERROR")
+from deeplabcut.core.trackingutils import TRACK_METHODS
def triangulate(
@@ -88,7 +84,7 @@ def triangulate(
To analyze only a few pair of videos:
>>> deeplabcut.triangulate(config,[['C:\\yourusername\\rig-95\\Videos\\video1-camera-1.avi','C:\\yourusername\\rig-95\\Videos\\video1-camera-2.avi'],['C:\\yourusername\\rig-95\\Videos\\video2-camera-1.avi','C:\\yourusername\\rig-95\\Videos\\video2-camera-2.avi']])
"""
- from deeplabcut.pose_estimation_tensorflow import predict_videos
+ from deeplabcut.compat import analyze_videos
from deeplabcut.post_processing import filtering
cfg_3d = auxiliaryfunctions.read_config(config)
@@ -263,7 +259,7 @@ def triangulate(
scorer_name[cam_names[j]] = DLCscorer
else:
# Analyze video if score name is different
- DLCscorer = predict_videos.analyze_videos(
+ DLCscorer = analyze_videos(
config_2d,
[video],
videotype=videotype,
@@ -293,7 +289,7 @@ def triangulate(
)
else: # need to do the whole jam.
- DLCscorer = predict_videos.analyze_videos(
+ DLCscorer = analyze_videos(
config_2d,
[video],
videotype=videotype,
@@ -446,18 +442,32 @@ def triangulate(
}
# Create 3D DataFrame column and row indices
- axis_labels = ("x", "y", "z")
+ cols = [
+ [scorer_3d],
+ list(auxiliaryfunctions.get_bodyparts(cfg)),
+ ["x", "y", "z"],
+ ]
+ cols_names = ["scorer", "bodyparts", "coords"]
+ flag_indiv_single = False
if cfg.get("multianimalproject"):
- columns = pd.MultiIndex.from_product(
- [[scorer_3d], individuals, bodyparts, axis_labels],
- names=["scorer", "individuals", "bodyparts", "coords"],
- )
-
- else:
- columns = pd.MultiIndex.from_product(
- [[scorer_3d], bodyparts, axis_labels],
- names=["scorer", "bodyparts", "coords"],
- )
+ cols_names.insert(1, "individuals")
+ if "single" == individuals[-1]:
+ individuals = individuals[:-1]
+ columns_unique = pd.MultiIndex.from_product(
+ [
+ [scorer_3d],
+ ["single"],
+ auxiliaryfunctions.get_unique_bodyparts(cfg),
+ ["x", "y", "z"],
+ ],
+ names=cols_names,
+ )
+ flag_indiv_single = True
+ cols.insert(1, individuals)
+ columns = pd.MultiIndex.from_product(cols, names=cols_names)
+ if flag_indiv_single:
+ columns = columns.append(columns_unique)
+ individuals.append("single")
inds = range(num_frames)
@@ -468,10 +478,10 @@ def triangulate(
df_3d = pd.DataFrame(triangulate, columns=columns, index=inds)
df_3d.to_hdf(
- str(output_filename + ".h5"),
- "df_with_missing",
- format="table",
+ str(output_filename) + ".h5",
+ key="df_with_missing",
mode="w",
+ format="table",
)
# Reorder 2D dataframe in view 2 to match order of view 1
@@ -483,19 +493,19 @@ def triangulate(
)
df_2d_view2.to_hdf(
dataname[1],
- "tracks",
+ key="tracks",
format="table",
mode="w",
)
auxiliaryfunctions_3d.SaveMetadata3d(
- str(output_filename + "_meta.pickle"), metadata
+ str(output_filename) + "_meta.pickle", metadata
)
if save_as_csv:
- df_3d.to_csv(str(output_filename + ".csv"))
+ df_3d.to_csv(str(output_filename) + ".csv")
- print("Triangulated data for video", video_list[i])
+ print("Triangulated data for video", video)
print("Results are saved under: ", destfolder)
# have to make the dest folder none so that it can be updated for a new pair of videos
if destfolder == str(Path(video).parents[0]):
diff --git a/deeplabcut/pose_estimation_pytorch/README.md b/deeplabcut/pose_estimation_pytorch/README.md
new file mode 100644
index 0000000000..5ec580d787
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/README.md
@@ -0,0 +1,510 @@
+# PyTorch DeepLabCut API
+
+This overview is primarily written for maintainers and expert users.
+
+Here we detail the logic and structure for the DLC3.* PyTorch code. Furthermore, we
+provide many practical examples to illustrate the usage of the code for developers.
+
+## Structure of the PyTorch DLC code
+
+[API](#API)
+
+[Models](#models)
+
+[Data](#data)
+
+[Runners](#runners)
+
+### API
+
+High-level API methods are implemented in `deeplabcut.pose_estimations_pytorch.apis`.
+This folder includes methods to train and evaluate models on DeepLabCut projects, and
+analyze videos or folders (of images). While some of the methods are implemented to work
+directly from DeepLabCut projects (i.e. by specifying the path to the project config
+file and the shuffle number), internally they call methods that allow more flexibility.
+Thus, they are also ideally suited for developers.
+
+### Models
+
+We provide state-of-the-art pose estimation models such as DLCRNet, HRNet, DEKR, BUCTD
+and more are coming! Object detection models are also available (and implemented in
+`deeplabcut.pose_estimations_pytorch.models.detectors`).
+
+The `deeplabcut.pose_estimations_pytorch.models` package contains all components related
+to building a model. Models are flexibly build from modular components: `backbone`,
+`neck` (optional) and `head` (as discussed below).
+
+You can check available models by running:
+
+```python
+import deeplabcut.pose_estimation_pytorch
+
+# Available pose estimation models
+print(deeplabcut.pose_estimation_pytorch.available_models())
+
+# Available object detection models
+print(deeplabcut.pose_estimation_pytorch.available_detectors())
+```
+
+#### Model Configuration Files
+
+Model architectures are built according to a configuration specified in a `yaml` file.
+This file (named `pytorch_cfg.yaml`) describes the architecture of the model you want to
+train (but also hyperparameters, optimizer, ...). All code to manipulate PyTorch
+configuration files is in `deeplabcut.pose_estimations_pytorch.config`.
+
+To generate a model configuration, you can call `make_pytorch_pose_config`. Note that
+this does not save the configuration to a given filepath - it just returns it as a
+dictionary. However, you can save it with `write_config`.
+
+During a typical DeepLabCut project management workflow, these methods don't need to be
+called, as `create_training_dataset` will create this configuration file and save it to
+disk.
+
+```python
+from pathlib import Path
+
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+project_cfg = { "Task": "mice", ... } # the configuration for your DLC project
+pose_config_path = Path("/path/to/my/config/pytorch_cfg.yaml")
+model_cfg = dlc_torch.config.make_pytorch_pose_config(
+ project_config=project_cfg,
+ pose_config_path=pose_config_path,
+ net_type="hrnet_w32",
+ top_down=True,
+ save=True,
+)
+```
+
+#### Adding Models
+
+If you want to add a novel model, you'll ideally build them from the following
+implemented parts:
+
+- a backbone (such as a ResNet or HRNet)
+- a head (such as a HeatmapHead)
+- a predictor (transforming model outputs into keypoint locations)
+- a target generator (creating the targets for your head outputs from your labels)
+
+Some models can also define a neck (model components between the backbone and the head).
+You'll also need some loss criterions, but usually you'll be able to use existing ones.
+
+You can either use existing classes and only replace some elements, or rewrite
+everything you need for your model. We use Model Registries to simplify the process of
+adding models.
+
+#### Model Registry
+
+Registries are created for all model building blocks to make it easy to add new models.
+All you need to do is add the decorator `REGISTRY.register_module` to be able to load
+your model from a configuration file. Available registries are `BACKBONES`, `NECKS`,
+`HEADS`, `PREDICTORS` and `TARGET_GENERATORS`. Each building block has a base class
+that should be inherited by the class added to the model registry (`BaseBackbone`,
+`BaseNeck`, `BaseHead`, `BasePredictor` and `BaseGenerator` respectively).
+
+Let's illustrate that with a small example. We'll create a dummy backbone, which simply
+applies a max-pool to the input:
+
+```python
+import torch
+import torch.nn.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.backbones import BACKBONES, BaseBackbone
+
+
+@BACKBONES.register_module
+class DummyBackbone(BaseBackbone):
+ """A dummy backbone, simply max-pooling the input"""
+
+ def __init__(self, kernel_size: int = 2):
+ super().__init__(stride=kernel_size)
+ self.kernel_size = kernel_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.max_pool2d(x, kernel_size=self.kernel_size)
+
+
+backbone_config = dict(type="DummyBackbone", kernel_size=3)
+backbone = BACKBONES.build(backbone_config) # will create a DummyBackbone
+```
+
+Another example would be creating a custom head for our model. In this case, let's make
+a head which takes as input the output of a backbone (which has shape `(num_channels,
+H', W')`) and put it through a kernel-size 1 convolution, simply changing the number of
+channels.
+
+Heads can output multiple tensors (such as heatmaps and location refinement fields).
+Therefore, their `forward(...)` method outputs a dictionary mapping strings to tensors.
+Here, we return the `heatmap` and `locref` tensors.
+
+A head must contain different: a `target_generator` to generate targets for
+its outputs and a `predictor` to convert model outputs to pose. Make sure that the keys
+output by the `target_generator` and the `head` match! Some `criterion` also needs to be
+defined to compute the loss between the outputs and targets. When more than one output
+is specified (such as in this case, where we're generating heatmaps and location
+refinement fields), a loss aggregator must also be given to combine all losses into one
+(this should simply be a `WeightedLossAggregator`, indicating the weight for each loss).
+
+```python
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+ WeightedHuberCriterion,
+ WeightedLossAggregator,
+ WeightedMSECriterion,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads import HEADS, BaseHead
+from deeplabcut.pose_estimation_pytorch.models.predictors import (
+ BasePredictor,
+ HeatmapPredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.target_generators import (
+ BaseGenerator,
+ HeatmapGaussianGenerator,
+)
+
+
+@HEADS.register_module
+class DummyHead(BaseHead):
+ """A dummy backbone, simply max-pooling the input"""
+
+ def __init__(
+ self,
+ num_input_channels: int,
+ num_bodyparts: int,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion],
+ aggregator: BaseLossAggregator,
+ ):
+ super().__init__(
+ stride=1,
+ predictor=predictor,
+ target_generator=target_generator,
+ criterion=criterion,
+ aggregator=aggregator
+ )
+ self.conv_heatmap = nn.Conv2d(
+ in_channels=num_input_channels,
+ out_channels=num_bodyparts,
+ kernel_size=1,
+ stride=1,
+ )
+ self.locref_heatmap = nn.Conv2d(
+ in_channels=num_input_channels,
+ out_channels=2 * num_bodyparts,
+ kernel_size=1,
+ stride=1,
+ )
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ return {
+ "heatmap": self.conv_heatmap(x),
+ "locref": self.locref_heatmap(x),
+ }
+
+
+head_config = dict(
+ type="DummyHead",
+ num_input_channels=2048,
+ num_bodyparts=5,
+ predictor=HeatmapPredictor(location_refinement=True, locref_std= 7.2801),
+ target_generator=HeatmapGaussianGenerator(
+ num_heatmaps=5,
+ pos_dist_thresh=17,
+ heatmap_mode=HeatmapGaussianGenerator.Mode.KEYPOINT,
+ generate_locref=True,
+ ),
+ criterion={
+ "heatmap": WeightedMSECriterion(),
+ "locref": WeightedHuberCriterion(),
+ },
+ aggregator=WeightedLossAggregator(weights={"heatmap": 1, "locref": 0.05}),
+)
+head = HEADS.build(head_config)
+```
+
+### Data
+
+The `deeplabcut.pose_estimations_pytorch.data` package contains all code for PyTorch
+dataset creation and test/train splitting. The `DLCLoader` class is used to load the
+labeled data for a specific shuffle.
+
+```python3
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+loader = dlc_torch.DLCLoader(
+ config="/path/to/my/project/config.yaml",
+ trainset_index=0,
+ shuffle=1,
+)
+
+# print the path to the model folder (where the config file is stored)
+print(loader.model_folder)
+# print the path to the evaluation folder
+print(loader.evaluation_folder)
+
+# display the DataFrame containing the dataset
+print(loader.df)
+
+# display the DataFrames containing the train/test data respectively
+print(loader.df_train)
+print(loader.df_test)
+```
+
+The `PoseDataset` class is an instance of
+[torch.utils.Dataset](https://pytorch.org/docs/stable/data.html), which converts raw
+images and keypoints to a tensor dataset for training and evaluation. You can generate
+an instance of training/test dataset with your `DLCLoader`:
+
+```python3
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+loader = dlc_torch.DLCLoader(
+ config="/path/to/my/project/config.yaml",
+ trainset_index=0,
+ shuffle=1,
+)
+train_dataset = loader.create_dataset(
+ transform=dlc_torch.build_transforms(loader.model_cfg["data"]["train"]),
+ mode="train",
+ task=loader.pose_task,
+)
+valid_dataset = loader.create_dataset(
+ transform=dlc_torch.build_transforms(loader.model_cfg["data"]["train"]),
+ mode="test",
+ task=loader.pose_task,
+)
+```
+
+A `COCOLoader` is also available, and allows you train models in DeepLabCut on
+[COCO-format](https://medium.com/@manuktiwary/coco-format-what-and-how-5c7d22cf5301)
+datasets. This essentially consists of having a folder containing your dataset in the
+format:
+
+```
+COCOProject
+└───annotations
+│ │ train.json
+│ │ test.json
+│
+└───images
+ │ img0000.png
+ │ img0001.png
+ │ ...
+```
+
+In your `train.json` and `test.json` files, you can either specify your image
+`"file_name"` with a relative path or with an absolute path. If a relative path is
+used (e.g. `img0000.png` or `subfolder/img0000.png`), it will be resolved to the
+`images` folder in your project (i.e. `/path/to/COCOProject/images/img0000.png` or
+`/path/to/COCOProject/images/subfolder/img0000.png`).
+
+If you specify an absolute path, the path to the image will not be resolved, and the
+image will be loaded from the specified path. This allows you to keep data on different
+disks, or reuse the same images in different projects without having to duplicate them.
+
+To train a DeepLabCut model on a COCO-format dataset, you'll need to specify a model
+configuration file (as described in [#model_configuration_files]).
+
+```python3
+from pathlib import Path
+
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+# Specify project paths
+project_root = Path("/path/to/my/COCOProject")
+train_json_filename = "train.json"
+test_json_filename = "test.json"
+
+# Parse information about the project
+train_dict = dlc_torch.COCOLoader.load_json(project_root, filename=train_json_filename)
+max_num_individuals, bodyparts = dlc_torch.COCOLoader.get_project_parameters(train_dict)
+
+# Generate a configuration file for your PyTorch model
+# In this case, it's for a Top-Down HRNet_w32
+experiment_path = project_root / "experiments" / "hrnet_w32"
+model_cfg_path = experiment_path / "train" / "pytorch_cfg.yaml"
+model_cfg = dlc_torch.config.make_pytorch_pose_config(
+ project_config=dlc_torch.config.make_basic_project_config(
+ dataset_path=str(project_root.resolve()),
+ bodyparts=bodyparts,
+ max_individuals=max_num_individuals,
+ multi_animal=True,
+ ),
+ pose_config_path=experiment_path,
+ net_type="hrnet_w32",
+ top_down=True,
+ save=True,
+)
+
+# Create the loader for the COCO dataset
+loader = dlc_torch.COCOLoader(
+ project_root=project_root,
+ model_config_path="/path/to/my/project/experiments/pytorch_config.yaml",
+ train_json_filename=train_json_filename,
+ test_json_filename=test_json_filename,
+)
+train_dataset = loader.create_dataset(
+ transform=dlc_torch.build_transforms(loader.model_cfg["data"]["train"]),
+ mode="train",
+ task=loader.pose_task,
+)
+valid_dataset = loader.create_dataset(
+ transform=dlc_torch.build_transforms(loader.model_cfg["data"]["train"]),
+ mode="test",
+ task=loader.pose_task,
+)
+```
+
+### Runners
+
+The `deeplabcut.pose_estimations_pytorch.runners` contains code to get models, load
+pretrained weights, and either train them or run inference with them.
+
+## Code Examples
+
+### Training a Model on a COCO Dataset
+
+```python
+from pathlib import Path
+
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+# Specify project paths
+project_root = Path("/path/to/my/COCOProject")
+train_json_filename = "train.json"
+test_json_filename = "test.json"
+
+loader = dlc_torch.COCOLoader(
+ project_root=project_root,
+ model_config_path="/path/to/my/project/experiments/pytorch_config.yaml",
+ train_json_filename=train_json_filename,
+ test_json_filename=test_json_filename,
+)
+dlc_torch.train(
+ loader=loader,
+ run_config=loader.model_cfg,
+ task=dlc_torch.Task(loader.model_cfg["method"]),
+ device="cuda:2",
+ logger_config=dict(
+ type="WandbLogger",
+ project_name="MyWandbProject",
+ tags=["model=hrnet_w32"],
+ ),
+ snapshot_path=None,
+)
+```
+
+### Running Video Analysis outside a DeepLabCut Project
+
+DeepLabCut provides high-level APIs (via the GUI or the python package) to analyze your
+data. The usage of this API assumes the existence of a DLC project (with `config.yaml`
+file, etc.).
+
+Sometimes it might be more convenient to just run a model on your data via a low-level
+API. We also use this API under the hood, in particular for the Model Zoo. Check out the
+example below:
+
+```python
+from deeplabcut.core.config import read_config_as_dict
+from pathlib import Path
+
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+
+train_dir = Path("/Users/Jaylen/my-dlc-models/train")
+pytorch_config_path = train_dir / "pytorch_config.yaml"
+snapshot_path = train_dir / "snapshot-100.pt"
+
+# for top-down models, otherwise None
+detector_snapshot_path = train_dir / "detector-snapshot-100.pt"
+
+# video and inference parameters
+video_path = Path("/Users/Jaylen/my-dlc-models/videos/test-video.mp4")
+max_num_animals = 5
+batch_size = 16
+detector_batch_size = 8
+
+# read model configuration
+model_cfg = read_config_as_dict(pytorch_config_path)
+pose_task = dlc_torch.Task(model_cfg["method"])
+pose_runner = dlc_torch.get_pose_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=snapshot_path,
+ max_individuals=max_num_animals,
+ batch_size=batch_size,
+)
+
+detector_runner = None
+if pose_task == dlc_torch.Task.TOP_DOWN:
+ detector_runner = dlc_torch.get_detector_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=detector_snapshot_path,
+ max_individuals=max_num_animals,
+ batch_size=detector_batch_size,
+ )
+
+predictions = dlc_torch.video_inference(
+ video=video_path,
+ pose_runner=pose_runner,
+ detector_runner=detector_runner,
+)
+```
+
+
+### Running Top-Down Video Analysis with Existing Bounding Boxes
+
+When `deeplabcut.pose_estimation_pytorch.apis.videos.video_inference` is called
+with a top-down model, it is assumed that a detector snapshot is given as well to obtain
+bounding boxes with which to run pose estimation. It's possible that you've already
+obtained bounding boxes for your video (with another object detector or through some
+other means), and you want to reuse those bounding boxes instead of running an object
+detector again.
+
+You can easily do so by writing a bit of custom code, as shown in the example below:
+
+```python
+from deeplabcut.core.config import read_config_as_dict
+from pathlib import Path
+
+import numpy as np
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+from tqdm import tqdm
+
+# create an iterator for your video
+video = dlc_torch.VideoIterator("/Users/Jayson/my-cool-video.mp4")
+
+# dummy bboxes - you can load yours from a file or in another way
+# the bboxes should be in `xywh` format, i.e. (x_top_left, y_top_left, width, height)
+bounding_boxes = [
+ dict( # frame 0 bounding boxes
+ bboxes=np.array([[12, 37, 120, 78]]),
+ ),
+ dict( # frame 1 bounding boxes
+ bboxes=np.array([[17, 45, 128, 73], [532, 34, 117, 87]]),
+ ),
+ # ...
+ dict( # frame N bboxes -> must be equal to the number of frames in the video!
+ bboxes=np.array([[17, 45, 128, 73], [532, 34, 117, 87]]),
+ ),
+]
+video.set_context(bounding_boxes)
+max_individuals = np.max([len(context["bboxes"]) for context in bounding_boxes])
+
+# run inference!
+model_cfg = read_config_as_dict("/Users/Jayson/pytorch_config.yaml")
+pose_runner = dlc_torch.get_pose_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=Path("/Users/Jayson/model-snapshot.pt"),
+ max_individuals=max_individuals,
+ batch_size=32,
+)
+
+# your predictions will be a list, containing the predictions made for each frame
+# as a dict (with keys for "bodyparts" but also "bboxes")!
+predictions = pose_runner.inference(images=tqdm(video))
+```
diff --git a/deeplabcut/pose_estimation_pytorch/__init__.py b/deeplabcut/pose_estimation_pytorch/__init__.py
new file mode 100644
index 0000000000..999ff14ab6
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/__init__.py
@@ -0,0 +1,61 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import deeplabcut.pose_estimation_pytorch.config as config
+from deeplabcut.pose_estimation_pytorch.apis import (
+ analyze_image_folder,
+ analyze_images,
+ analyze_videos,
+ build_predictions_dataframe,
+ create_labeled_images,
+ create_tracking_dataset,
+ convert_detections2tracklets,
+ evaluate,
+ evaluate_network,
+ extract_maps,
+ extract_save_all_maps,
+ get_detector_inference_runner,
+ get_pose_inference_runner,
+ predict,
+ superanimal_analyze_images,
+ train,
+ train_network,
+ video_inference,
+ VideoIterator,
+ visualize_predictions,
+)
+from deeplabcut.pose_estimation_pytorch.config import (
+ available_detectors,
+ available_models,
+)
+from deeplabcut.pose_estimation_pytorch.data import (
+ build_transforms,
+ COCOLoader,
+ COLLATE_FUNCTIONS,
+ DLCLoader,
+ Loader,
+ PoseDataset,
+ PoseDatasetParameters,
+)
+from deeplabcut.pose_estimation_pytorch.runners import (
+ build_inference_runner,
+ build_training_runner,
+ DetectorInferenceRunner,
+ DetectorTrainingRunner,
+ get_load_weights_only,
+ InferenceRunner,
+ PoseInferenceRunner,
+ PoseTrainingRunner,
+ set_load_weights_only,
+ TorchSnapshotManager,
+ TrainingRunner,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.pose_estimation_pytorch.utils import fix_seeds
diff --git a/deeplabcut/pose_estimation_pytorch/apis/__init__.py b/deeplabcut/pose_estimation_pytorch/apis/__init__.py
new file mode 100644
index 0000000000..de20610fcf
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/__init__.py
@@ -0,0 +1,54 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+from deeplabcut.pose_estimation_pytorch.apis.analyze_images import (
+ analyze_image_folder,
+ analyze_images,
+ analyze_image_folder,
+ superanimal_analyze_images,
+)
+from deeplabcut.pose_estimation_pytorch.apis.videos import (
+ analyze_videos,
+ video_inference,
+ VideoIterator,
+)
+from deeplabcut.pose_estimation_pytorch.apis.tracklets import (
+ convert_detections2tracklets,
+)
+from deeplabcut.pose_estimation_pytorch.apis.evaluation import (
+ predict,
+ evaluate,
+ evaluate_network,
+ visualize_predictions,
+)
+from deeplabcut.pose_estimation_pytorch.apis.export import export_model
+from deeplabcut.pose_estimation_pytorch.apis.tracking_dataset import (
+ create_tracking_dataset,
+)
+from deeplabcut.pose_estimation_pytorch.apis.training import (
+ train,
+ train_network,
+)
+from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ get_detector_inference_runner,
+ get_inference_runners,
+ get_pose_inference_runner,
+)
+from deeplabcut.pose_estimation_pytorch.apis.visualization import (
+ create_labeled_images,
+ extract_maps,
+ extract_save_all_maps,
+)
+from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ build_predictions_dataframe,
+ get_detector_inference_runner,
+ get_pose_inference_runner,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/apis/analyze_images.py b/deeplabcut/pose_estimation_pytorch/apis/analyze_images.py
new file mode 100644
index 0000000000..3a7fc15ca2
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/analyze_images.py
@@ -0,0 +1,619 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import copy
+import glob
+import json
+import logging
+import os
+from collections import defaultdict
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+from tqdm import tqdm
+
+import deeplabcut.core.config as config_utils
+import deeplabcut.pose_estimation_pytorch.apis.visualization as visualization
+import deeplabcut.pose_estimation_pytorch.data as data
+import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
+from deeplabcut.core.engine import Engine
+from deeplabcut.modelzoo.utils import get_superanimal_colormaps
+from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ get_detector_inference_runner,
+ build_predictions_dataframe,
+ get_model_snapshots,
+ get_pose_inference_runner,
+ get_scorer_name,
+ get_scorer_uid,
+ parse_snapshot_index_for_analysis,
+)
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import update_config
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.pose_estimation_pytorch.utils import resolve_device
+from deeplabcut.utils import auxfun_videos, auxiliaryfunctions
+
+
+def superanimal_analyze_images(
+ superanimal_name: str,
+ model_name: str,
+ detector_name: str,
+ images: str | Path | list[str] | list[Path],
+ max_individuals: int,
+ out_folder: str | Path,
+ progress_bar: bool = True,
+ device: str | None = None,
+ pose_threshold: float = 0.4,
+ bbox_threshold: float = 0.6,
+ plot_skeleton: bool = True,
+ customized_model_config: str | Path | dict | None = None,
+ customized_pose_checkpoint: str | Path | None = None,
+ customized_detector_checkpoint: str | Path | None = None,
+) -> dict[str, dict]:
+ """
+ This function inferences a superanimal model on a set of images and saves the
+ results as labeled images.
+
+ Args:
+ superanimal_name: str
+ The name of the SuperAnimal to analyze. Supported list:
+ - "superanimal_bird"
+ - "superanimal_topviewmouse"
+ - "superanimal_quadruped"
+
+ model_name: str
+ The name of the pose model architecture to use for inference. To get a list
+ of available models for a SuperAnimal, call:
+ >>> import dlclibrary
+ >>> superanimal_name = "superanimal_topviewmouse"
+ >>> dlclibrary.get_available_models(superanimal_name)
+
+ detector_name: str
+ The name of the detector architecture to use for inference. To get a list
+ of available detectors for a SuperAnimal, call:
+ >>> import dlclibrary
+ >>> superanimal_name = "superanimal_topviewmouse"
+ >>> dlclibrary.get_available_detectors(superanimal_name)
+
+ images: str, Path, list[str], list[Path]
+ The images to analyze. Can either be a directory containing images, or
+ a list of paths of images.
+
+ max_individuals: int
+ The maximum number of individuals to detect in each image.
+
+ out_folder: str | Path
+ The directory where the labeled images will be saved.
+
+ progress_bar: bool, default=True
+ Whether to display a progress bar when running inference.
+
+ device: str | None, default=None
+ The device to use to run image analysis.
+
+ pose_threshold: float, default=0.4
+ The cutoff score when plotting pose predictions. To note, this is called
+ pcutoff in other parts of the code. Must be in (0, 1).
+
+ bbox_threshold: float, default=0.1
+ The minimum confidence score to keep bounding box detections. Must be in
+ (0, 1).
+
+ plot_skeleton: bool, default=True
+ If a skeleton is defined in the model configuration file, whether to plot
+ the skeleton connecting the predicted bodyparts on the images.
+
+ customized_model_config: str | Path | dict | None
+ A customized SuperAnimal model config, as an alternative to the default
+ SuperAnimal model config. You can get the default SuperAnimal config with:
+ >>> import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
+ >>> config = modelzoo.load_super_animal_config(
+ >>> super_animal, model_name, detector_name,
+ >>> )
+
+ customized_pose_checkpoint: str | None
+ A customized SuperAnimal pose checkpoint, as an alternative to the
+ HuggingFace SuperAnimal models.
+
+ customized_detector_checkpoint: str | None
+ A customized SuperAnimal detector checkpoint, as an alternative to the
+ HuggingFace SuperAnimal models.
+
+ Returns:
+ The predictions made by the model for each image.
+
+ Examples:
+ >>> from deeplabcut.pose_estimation_pytorch.apis import (
+ >>> superanimal_analyze_images
+ >>> )
+ >>> predictions = superanimal_analyze_images(
+ >>> superanimal_name="superanimal_topviewmouse",
+ >>> model_name="resnet_50",
+ >>> detector_name="fasterrcnn_mobilenet_v3_large_fpn",
+ >>> images="test_mouse_images",
+ >>> max_individuals=3,
+ >>> out_folder="test_mouse_images_labeled",
+ >>> device="cuda:0",
+ >>> pose_threshold=0.1,
+ >>> )
+ """
+ out_folder = Path(out_folder)
+ out_folder.mkdir(exist_ok=True, parents=True)
+
+ if customized_pose_checkpoint is None:
+ snapshot_path = modelzoo.get_super_animal_snapshot_path(
+ dataset=superanimal_name,
+ model_name=model_name,
+ )
+ else:
+ snapshot_path = Path(customized_pose_checkpoint)
+
+ if customized_detector_checkpoint is None:
+ detector_path = modelzoo.get_super_animal_snapshot_path(
+ dataset=superanimal_name,
+ model_name=detector_name,
+ )
+ else:
+ detector_path = Path(customized_detector_checkpoint)
+
+ if customized_model_config is None:
+ config = modelzoo.load_super_animal_config(
+ super_animal=superanimal_name,
+ model_name=model_name,
+ detector_name=detector_name,
+ )
+ elif isinstance(customized_model_config, (str, Path)):
+ config = config_utils.read_config_as_dict(customized_model_config)
+ else:
+ config = copy.deepcopy(customized_model_config)
+
+ config = update_config(config, max_individuals, device)
+ config["metadata"]["individuals"] = [f"animal{i}" for i in range(max_individuals)]
+ if "detector" in config:
+ config["detector"]["model"]["box_score_thresh"] = bbox_threshold
+
+ predictions = analyze_image_folder(
+ model_cfg=config,
+ images=images,
+ snapshot_path=snapshot_path,
+ detector_path=detector_path,
+ max_individuals=max_individuals,
+ device=device,
+ progress_bar=progress_bar,
+ )
+
+ skeleton_bodyparts = config.get("skeleton", [])
+ skeleton = None
+ if plot_skeleton and len(skeleton_bodyparts) > 0:
+ skeleton = []
+ bodyparts = config["metadata"]["bodyparts"]
+ for bpt_0, bpt_1 in skeleton_bodyparts:
+ skeleton.append(
+ (bodyparts.index(bpt_0), bodyparts.index(bpt_1))
+ )
+
+ visualization.create_labeled_images(
+ predictions=predictions,
+ out_folder=out_folder,
+ pcutoff=pose_threshold,
+ bboxes_pcutoff=bbox_threshold,
+ cmap=get_superanimal_colormaps()[superanimal_name],
+ skeleton=skeleton,
+ skeleton_color=config.get("skeleton_color", "black"),
+ close_figure_after_save=False,
+ )
+
+ return predictions
+
+
+def analyze_images(
+ config: str | Path,
+ images: str | Path | list[str] | list[Path],
+ frame_type: str | None = None,
+ output_dir: str | Path | None = None,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ snapshot_index: int | None = None,
+ detector_snapshot_index: int | None = None,
+ modelprefix: str = "",
+ device: str | None = None,
+ max_individuals: int | None = None,
+ save_as_csv: bool = False,
+ progress_bar: bool = True,
+ plotting: bool | str = False,
+ pcutoff: float | None = None,
+ bbox_pcutoff: float | None = None,
+ plot_skeleton: bool = True,
+) -> dict[str, dict]:
+ """Runs analysis on images using a pose model.
+
+ Args:
+ config: The project configuration file.
+ images: The image(s) to run inference on. Can be the path to an image, the path
+ to a directory containing images, or a list of image paths or directories
+ containing images.
+ frame_type: Filters the images to analyze to only the ones with the given suffix
+ (e.g. setting `frame_type`=".png" will only analyze ".png" images). The
+ default behavior analyzes all ".jpg", ".jpeg" and ".png" images.
+ output_dir: The directory where the predictions will be stored.
+ shuffle: The shuffle for which to run image analysis.
+ trainingsetindex: The trainingsetindex for which to run image analysis.
+ snapshot_index: The index of the snapshot to use. Loaded from the project
+ configuration file if None.
+ detector_snapshot_index: For top-down models only. The index of the detector
+ snapshot to use. Loaded from the project configuration file if None.
+ modelprefix: The model prefix used for the shuffle.
+ device: The device to use to run image analysis.
+ max_individuals: The maximum number of individuals to detect in each image. Set
+ to the number of individuals in the project if None.
+ save_as_csv: Whether to also save the predictions as a CSV file.
+ progress_bar: Whether to display a progress bar when running inference.
+ plotting: Whether to plot predictions on images.
+ pcutoff: The cutoff score when plotting pose predictions. Must be None or in
+ (0, 1). If None, the pcutoff is read from the project configuration file.
+ bbox_pcutoff: The cutoff score when plotting bounding box predictions. Must be
+ None or in (0, 1). If None, it is read from the project configuration file.
+ plot_skeleton: If a skeleton is defined in the model configuration file, whether
+ to plot the skeleton connecting the predicted bodyparts on the images.
+
+ Returns:
+ A dictionary mapping each image filename to the different types of predictions
+ for it (e.g. "bodyparts", "unique_bodyparts", "bboxes", "bbox_scores")
+ """
+ cfg = auxiliaryfunctions.read_config(config)
+ train_frac = cfg["TrainingFraction"][trainingsetindex]
+ model_folder = Path(cfg["project_path"]) / auxiliaryfunctions.get_model_folder(
+ train_frac,
+ shuffle,
+ cfg,
+ engine=Engine.PYTORCH,
+ modelprefix=modelprefix,
+ )
+ train_folder = model_folder / "train"
+
+ model_cfg_path = train_folder / Engine.PYTORCH.pose_cfg_name
+ model_cfg = config_utils.read_config_as_dict(model_cfg_path)
+ pose_task = Task(model_cfg["method"])
+
+ # get the snapshots to analyze images with
+ snapshot_index, detector_snapshot_index = parse_snapshot_index_for_analysis(
+ cfg, model_cfg, snapshot_index, detector_snapshot_index
+ )
+ snapshot = get_model_snapshots(snapshot_index, train_folder, pose_task)[0]
+ detector_snapshot = None
+ if detector_snapshot_index is not None:
+ detector_snapshot = get_model_snapshots(
+ detector_snapshot_index, train_folder, Task.DETECT
+ )[0]
+
+ predictions = analyze_image_folder(
+ model_cfg=model_cfg,
+ images=images,
+ snapshot_path=snapshot.path,
+ detector_path=None if detector_snapshot is None else detector_snapshot.path,
+ frame_type=frame_type,
+ device=device,
+ max_individuals=max_individuals,
+ progress_bar=progress_bar,
+ )
+
+ if len(predictions) == 0:
+ print(f"Found no images in {images}")
+ return {}
+
+ if output_dir is None:
+ images = list(predictions.keys())
+ output_dir = Path(images[0]).parent.resolve()
+ print(f"Setting output directory to {output_dir}")
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(exist_ok=True)
+
+ scorer = get_scorer_name(
+ cfg,
+ shuffle=shuffle,
+ train_fraction=train_frac,
+ snapshot_uid=get_scorer_uid(snapshot, detector_snapshot),
+ modelprefix=modelprefix,
+ )
+ individuals = model_cfg["metadata"]["individuals"]
+ if max_individuals is not None:
+ individuals = [f"individual{i}" for i in range(max_individuals)]
+
+ df_predictions = build_predictions_dataframe(
+ scorer=scorer,
+ predictions=predictions,
+ parameters=data.PoseDatasetParameters(
+ bodyparts=model_cfg["metadata"]["bodyparts"],
+ unique_bpts=model_cfg["metadata"]["unique_bodyparts"],
+ individuals=individuals,
+ ),
+ image_name_to_index=None,
+ )
+
+ output_filepath = output_dir / f"image_predictions_{scorer}.h5"
+ print(f"Saving predictions to {output_filepath}")
+
+ df_predictions.to_hdf(output_filepath, key="predictions")
+ if save_as_csv:
+ print(f"Saving CSV as {output_filepath}")
+ df_predictions.to_csv(output_filepath.with_suffix(".csv"))
+
+ if plotting:
+ plot_dir = output_dir / f"LabeledImages_{scorer}"
+ plot_dir.mkdir(exist_ok=True)
+
+ mode = plotting if isinstance(plotting, str) else "bodypart"
+
+ bodyparts = model_cfg["metadata"]["bodyparts"]
+ skeleton = None
+ if plot_skeleton and len(cfg.get("skeleton", [])) > 0:
+ skeleton = [
+ (bodyparts.index(bpt_0), bodyparts.index(bpt_1))
+ for bpt_0, bpt_1 in cfg["skeleton"]
+ ]
+
+ if pcutoff is None:
+ pcutoff = cfg.get("pcutoff", 0.6)
+ if bbox_pcutoff is None:
+ bbox_pcutoff = cfg.get("bbox_pcutoff", 0.6)
+
+ visualization.create_labeled_images(
+ predictions=predictions,
+ out_folder=plot_dir,
+ pcutoff=pcutoff,
+ bboxes_pcutoff=bbox_pcutoff,
+ mode=mode,
+ cmap=cfg.get("colormap", "rainbow"),
+ dot_size=cfg.get("dotsize", 12),
+ alpha_value=cfg.get("alphavalue", 12),
+ skeleton=skeleton,
+ skeleton_color=cfg.get("skeleton_color"),
+ )
+
+ return predictions
+
+
+def analyze_image_folder(
+ model_cfg: str | Path | dict,
+ images: str | Path | list[str] | list[Path],
+ snapshot_path: str | Path,
+ detector_path: str | Path | None = None,
+ frame_type: str | None = None,
+ device: str | None = None,
+ max_individuals: int | None = None,
+ progress_bar: bool = True,
+) -> dict[str, dict[str, np.ndarray | np.ndarray]]:
+ """Runs pose inference on a folder of images and returns the predictions
+
+ Args:
+ model_cfg: The model config (or its path) used to analyze the images.
+ images: The images to analyze. Can either be a directory containing images, or
+ a list of paths of images.
+ snapshot_path: The path of the snapshot to use to analyze the images.
+ detector_path: The path of the detector snapshot to use to analyze the images,
+ if a top-down model was used.
+ frame_type: Filters the images to analyze to only the ones with the given suffix
+ (e.g. setting `frame_type`=".png" will only analyze ".png" images). The
+ default behavior analyzes all ".jpg", ".jpeg" and ".png" images.
+ device: The device to use to run image analysis.
+ max_individuals: The maximum number of individuals to detect in each image. Set
+ to the number of individuals in the project if None.
+ progress_bar: Whether to display a progress bar when running inference.
+
+ Returns:
+ A dictionary mapping each image filename to the different types of predictions
+ for it (e.g. "bodyparts", "unique_bodyparts", "bboxes", "bbox_scores")
+
+ Raises:
+ ValueError: if the pose model is a top-down model but no detector path is given
+ """
+ if not isinstance(model_cfg, dict):
+ model_cfg = config_utils.read_config_as_dict(model_cfg)
+
+ pose_task = Task(model_cfg["method"])
+ if pose_task == Task.TOP_DOWN and detector_path is None:
+ raise ValueError(
+ "A detector path must be specified for image analysis using top-down models"
+ f" Please specify the `detector_path` parameter."
+ )
+
+ if max_individuals is None:
+ max_individuals = len(model_cfg["metadata"]["individuals"])
+
+ if device is None:
+ device = resolve_device(model_cfg)
+
+ pose_runner = get_pose_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=snapshot_path,
+ device=device,
+ max_individuals=max_individuals,
+ )
+
+ image_suffixes = ".png", ".jpg", ".jpeg"
+ if frame_type is not None:
+ image_suffixes = (frame_type, )
+
+ image_paths = parse_images_and_image_folders(images, image_suffixes)
+ pose_inputs = image_paths
+ if detector_path is not None:
+ logging.info(f"Running object detection with {detector_path}")
+ detector_runner = get_detector_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=detector_path,
+ device=device,
+ max_individuals=max_individuals,
+ )
+
+ detector_image_paths = image_paths
+ if progress_bar:
+ detector_image_paths = tqdm(detector_image_paths)
+ bbox_predictions = detector_runner.inference(images=detector_image_paths)
+ pose_inputs = list(zip(image_paths, bbox_predictions))
+
+ logging.info(f"Running pose estimation with {detector_path}")
+
+ if progress_bar:
+ pose_inputs = tqdm(pose_inputs)
+
+ predictions = pose_runner.inference(pose_inputs)
+
+ return {
+ image_path: image_predictions
+ for image_path, image_predictions in zip(image_paths, predictions)
+ }
+
+
+def plot_images_coco(
+ model_cfg: str | Path | dict,
+ image_folder: str | Path,
+ snapshot_path: str | Path,
+ out_path: str = "test_images",
+ data_json_path: str = "",
+ detector_path: str | Path | None = None,
+ device: str | None = None,
+ max_individuals: int | None = None,
+) -> list[dict]:
+ """
+ Runs pose inference on a folder of images from a COCO dataset, and plots all
+ predicted keypoints and bounding boxes
+
+ Args:
+ model_cfg: The model config (or its path) used to analyze the images.
+ image_folder: The path to the folder containing the images to analyze.
+ snapshot_path: The path of the snapshot to use to analyze the images.
+ out_path: The path of the folder where images should be output.
+ data_json_path: The path to the JSON file containing ground truth data.
+ detector_path: The path of the detector snapshot to use to analyze the images,
+ if a top-down model was used.
+ device: The device on which to run image inference
+ max_individuals: The maximum number of individuals to detect in an image.
+
+ Returns:
+ A list of dictionaries containing predictions made on each image.
+
+ Raises:
+ ValueError: if a top-down model configuration is given but detector_path is None
+ """
+ with open(data_json_path, "r") as f:
+ obj = json.load(f)
+
+ coco_images = obj["images"]
+ coco_annotations = obj["annotations"]
+
+ image_name_to_id = {}
+ for image in coco_images:
+ # only works with relative path as a test image can be in a different folder
+ image_name = image["file_name"].split(os.sep)[-1]
+ image_name_to_id[image_name] = image["id"]
+
+ image_id_to_annotations = defaultdict(list)
+ image_ids = list(image_name_to_id.values())
+ for annotation in coco_annotations:
+ image_id = annotation["image_id"]
+ if annotation["image_id"] in image_ids:
+ image_id_to_annotations[image_id].append(annotation)
+
+ # need to support more image types
+ images_in_folder = glob.glob(str(Path(image_folder) / "*.png"))
+ corresponded_images = []
+ for image in images_in_folder:
+ image_path = image
+ image_name = image.split(os.sep)[-1]
+ if image_name in image_name_to_id:
+ corresponded_images.append(image_path)
+
+ images = corresponded_images
+
+ predictions = analyze_image_folder(
+ model_cfg=model_cfg,
+ images=images,
+ snapshot_path=snapshot_path,
+ detector_path=detector_path,
+ device=device,
+ max_individuals=max_individuals,
+ progress_bar=True,
+ )
+
+ os.makedirs(out_path, exist_ok=True)
+
+ coco_format_predictions = []
+ for image_path, prediction in predictions.items():
+ image_name = image_path.split(os.sep)[-1]
+ coco_prediction = dict(
+ image_id=image_name_to_id[image_name],
+ gt_annotations=image_id_to_annotations[image_name_to_id[image_name]],
+ file_name=image_path,
+ bodyparts=prediction["bodyparts"],
+ )
+ if "unique_bodyparts" in prediction:
+ coco_prediction["unique_bodyparts"] = prediction["unique_bodyparts"]
+ if "bboxes" in prediction:
+ coco_prediction["bboxes"] = prediction["bboxes"]
+ if "bbox_scores" in prediction:
+ coco_prediction["bbox_scores"] = prediction["bbox_scores"]
+
+ coco_format_predictions.append(coco_prediction)
+
+ frame = auxfun_videos.imread(str(image_path), mode="skimage")
+ fig, ax = plt.subplots()
+ ax.imshow(frame)
+
+ # TODO: color of keypoints are all red. Need to change to a different colormap
+ for pose in prediction["bodyparts"]:
+ x, y, confidence = pose[:, 0], pose[:, 1], pose[:, 2]
+ mask = confidence > 0.0
+ x = x[mask]
+ y = y[mask]
+ ax.scatter(x, y, color="red")
+
+ bboxes = prediction["bboxes"]
+ for bbox in bboxes:
+ # Draw bounding boxes around detected objects
+ xmin, ymin, w, h = bbox
+ rect = plt.Rectangle(
+ (xmin, ymin), w, h, fill=False, edgecolor="blue", linewidth=2
+ )
+
+ ax.add_patch(rect)
+ image_name = image_path.split("/")[-1]
+ fig.savefig(os.path.join(out_path, image_name))
+
+ return coco_format_predictions
+
+
+def parse_images_and_image_folders(
+ images: str | Path | list[str] | list[Path],
+ image_suffixes: tuple[str] = (".png", ".jpg", ".jpeg"),
+) -> list[str]:
+ """Parses image paths or directory paths into a single list of image paths.
+
+ Args:
+ images: Paths of images or folders containing images.
+ image_suffixes: Suffixes used for images.
+
+ Returns:
+ The images contained in the folders or directly the paths given as input
+ """
+ if isinstance(images, (str, Path)):
+ path = Path(images)
+ if path.is_dir():
+ return [str(img) for img in path.iterdir() if img.suffix in image_suffixes]
+
+ return [str(path)]
+
+ image_to_analyze = []
+ for file in images:
+ image_to_analyze += parse_images_and_image_folders(file)
+
+ return image_to_analyze
diff --git a/deeplabcut/pose_estimation_pytorch/apis/evaluation.py b/deeplabcut/pose_estimation_pytorch/apis/evaluation.py
new file mode 100755
index 0000000000..d3b15d1822
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/evaluation.py
@@ -0,0 +1,1030 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from typing import Iterable
+
+import albumentations as A
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import deeplabcut.core.metrics as metrics
+import deeplabcut.pose_estimation_pytorch.apis.prune_paf_graph as prune_paf_graph
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch import utils
+from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ build_predictions_dataframe,
+ ensure_multianimal_df_format,
+ get_inference_runners,
+ get_model_snapshots,
+ get_scorer_name,
+ get_scorer_uid,
+ build_bboxes_dict_for_dataframe,
+)
+from deeplabcut.pose_estimation_pytorch.data import DLCLoader, Loader
+from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters
+from deeplabcut.pose_estimation_pytorch.runners import InferenceRunner
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.utils import auxfun_videos, auxiliaryfunctions
+from deeplabcut.utils.visualization import (
+ create_minimal_figure,
+ erase_artists,
+ get_cmap,
+ make_multianimal_labeled_image,
+ plot_evaluation_results,
+ save_labeled_frame,
+)
+
+
+def predict(
+ pose_runner: InferenceRunner,
+ loader: Loader,
+ mode: str,
+ detector_runner: InferenceRunner | None = None,
+) -> dict[str, dict[str, np.ndarray]]:
+ """Predicts poses on data contained in a loader
+
+ Args:
+ pose_runner: The runner to use for pose estimation
+ loader: The loader containing the data to predict poses on
+ mode: {"train", "test"} The mode to predict on
+ detector_runner: If the loader's `pose_task` is "TD", a detector runner can be
+ given to detect individuals in the images. If no detector is given, ground
+ truth bounding boxes will be used to crop individuals before pose estimation
+
+ Returns:
+ The paths of images for which predictions were computed mapping to the
+ different predictions made by each model head
+ """
+ image_paths = loader.image_filenames(mode)
+ context = None
+
+ if loader.pose_task == Task.TOP_DOWN:
+ # Get bounding boxes for context
+ if detector_runner is not None:
+ bbox_predictions = detector_runner.inference(images=tqdm(image_paths))
+ context = bbox_predictions
+ else:
+ ground_truth_bboxes = loader.ground_truth_bboxes(mode=mode)
+ context = [
+ {"bboxes": ground_truth_bboxes[image]["bboxes"]}
+ for image in image_paths
+ ]
+
+ images_with_context = image_paths
+ if context is not None:
+ if len(context) != len(image_paths):
+ raise ValueError(
+ f"Missing context for some images: {len(context)} != {len(image_paths)}"
+ )
+ images_with_context = list(zip(image_paths, context))
+
+ predictions = pose_runner.inference(images=tqdm(images_with_context))
+ return {
+ image_path: image_predictions
+ for image_path, image_predictions in zip(image_paths, predictions)
+ }
+
+
+def evaluate(
+ pose_runner: InferenceRunner,
+ loader: Loader,
+ mode: str,
+ detector_runner: InferenceRunner | None = None,
+ parameters: PoseDatasetParameters | None = None,
+ comparison_bodyparts: str | list[str] | None = None,
+ per_keypoint_evaluation: bool = False,
+ pcutoff: float | list[float] = 0.6,
+) -> tuple[dict[str, float], dict[str, dict[str, np.ndarray]]]:
+ """
+ Args:
+ pose_runner: The runner for pose estimation
+ loader: The loader containing the data to evaluate
+ mode: Either 'train' or 'test'
+ detector_runner: If the loader's `pose_task` is "TD", a detector can be given to
+ compute bounding boxes for pose estimation. If no detector is given, ground
+ truth bounding boxes are used.
+ parameters: PoseDatasetParameters to use. If None, the parameters will be
+ obtained from the given Loader. This can be used to change the names of
+ bodyparts, e.g. when a model is trained with memory replay.
+ comparison_bodyparts: A subset of the bodyparts for which to compute the
+ evaluation metrics. Passing "all" or None evaluates on all bodyparts.
+ per_keypoint_evaluation: Compute the train and test RMSE for each keypoint, and
+ save the results to a {model_name}-keypoint-results.csv in the
+ evaluation-results-pytorch folder.
+ pcutoff: Confidence threshold for RMSE computation. If a list is provided,
+ there should be one value for each bodypart and one value for each unique
+ bodypart (if there are any).
+
+ Returns:
+ A dict containing the evaluation results
+ A dict mapping the paths of images for which predictions were computed to the
+ different predictions made by each model head
+ """
+ predictions = predict(pose_runner, loader, mode, detector_runner=detector_runner)
+
+ # For models trained with memory-replay from SuperAnimal, keep project bodyparts
+ if weight_init_cfg := loader.model_cfg["train_settings"].get("weight_init"):
+ weight_init = WeightInitialization.from_dict(weight_init_cfg)
+ if weight_init.memory_replay:
+ for _, pred in predictions.items():
+ pred["bodyparts"] = pred["bodyparts"][:, weight_init.conversion_array]
+
+ if parameters is None:
+ parameters = loader.get_dataset_parameters()
+
+ gt_pose = loader.ground_truth_keypoints(mode)
+ pred_pose = {filename: pred["bodyparts"] for filename, pred in predictions.items()}
+ kpt_idx = _get_keypoints_to_use(parameters.bodyparts, comparison_bodyparts)
+
+ gt_unique, pred_unique, unique_idx = None, None, None
+ if parameters.num_unique_bpts >= 1:
+ gt_unique = loader.ground_truth_keypoints(mode, unique_bodypart=True)
+ pred_unique = {
+ filename: pred["unique_bodyparts"] for filename, pred in predictions.items()
+ }
+ unique_idx = _get_keypoints_to_use(parameters.unique_bpts, comparison_bodyparts)
+
+ # When `comparison_bodyparts` is used, check that the bodyparts used for evaluation
+ # make sense; If only unique bodyparts are being evaluated, set them as bodyparts
+ if kpt_idx is not None and unique_idx is not None:
+ if len(kpt_idx) == 0 and len(unique_idx) == 0:
+ unique_err = ""
+ if len(parameters.unique_bpts) > 0:
+ unique_err = f" and the unique_bodyparts are {parameters.unique_bpts}"
+ raise ValueError(
+ f"No bodyparts left when comparison_bodyparts={comparison_bodyparts}! "
+ f"The project bodyparts are {parameters.bodyparts}{unique_err}! Set "
+ f"comparison_bodyparts to `None` or `'all'` to evaluate on all of them,"
+ f" or select a subset of them to evaluate."
+ )
+ elif len(kpt_idx) == 0 and len(unique_idx) > 0:
+ gt_pose, pred_pose, kpt_idx = gt_unique, pred_unique, unique_idx
+ parameters = PoseDatasetParameters(
+ bodyparts=parameters.unique_bpts,
+ unique_bpts=[],
+ individuals=["animal"],
+ )
+ gt_unique, pred_unique, unique_idx = None, None, None
+
+ if kpt_idx is not None:
+ gt_pose = {img: kpts[:, kpt_idx] for img, kpts in gt_pose.items()}
+ pred_pose = {img: kpts[:, kpt_idx] for img, kpts in pred_pose.items()}
+
+ if unique_idx is not None:
+ gt_unique = {img: kpts[:, unique_idx] for img, kpts in gt_unique.items()}
+ pred_unique = {img: kpts[:, unique_idx] for img, kpts in pred_unique.items()}
+
+ bodyparts = _get_subset_bodyparts(parameters.bodyparts, comparison_bodyparts)
+ unique_bpts = _get_subset_bodyparts(parameters.unique_bpts, comparison_bodyparts)
+ _validate_pcutoff(bodyparts, unique_bpts, pcutoff)
+
+ results = metrics.compute_metrics(
+ gt_pose,
+ pred_pose,
+ single_animal=parameters.max_num_animals == 1,
+ pcutoff=pcutoff,
+ unique_bodypart_poses=pred_unique,
+ unique_bodypart_gt=gt_unique,
+ per_keypoint_rmse=per_keypoint_evaluation,
+ compute_detection_rmse=False,
+ )
+
+ if loader.model_cfg["metadata"]["with_identity"]:
+ pred_id_scores = {
+ filename: pred["identity_scores"] for filename, pred in predictions.items()
+ }
+ id_scores = metrics.compute_identity_scores(
+ individuals=parameters.individuals,
+ bodyparts=parameters.bodyparts,
+ predictions=pred_pose,
+ identity_scores=pred_id_scores,
+ ground_truth=gt_pose,
+ )
+ for name, score in id_scores.items():
+ results[f"id_head_{name}"] = score
+
+ # Updating poses to be aligned and padded
+ for image, pose in pred_pose.items():
+ predictions[image]["bodyparts"] = pose
+
+ return results, predictions
+
+
+def visualize_predictions(
+ predictions: dict,
+ ground_truth: dict,
+ output_dir: str | Path | None = None,
+ num_samples: int | None = None,
+ random_select: bool = False,
+ show_ground_truth: bool = True,
+ plot_bboxes: bool = True,
+) -> None:
+ """Visualize model predictions alongside ground truth keypoints.
+
+ This function processes keypoint predictions and ground truth data, applies
+ visibility masks, and generates visualization plots. It supports random or
+ sequential sampling of images for visualization.
+
+ Args:
+ predictions: Dictionary mapping image paths to prediction data.
+ Each prediction contains:
+ - bodyparts: array of shape [N, num_keypoints, 3] where 3 represents
+ (x, y, confidence)
+ - bboxes: array of shape [N, 4] for bounding boxes (optional)
+ - bbox_scores: array of shape [N,] for bbox confidences (optional)
+ ground_truth: Dictionary mapping image paths to ground truth keypoints.
+ Each value has shape [N, num_keypoints, 3] where 3 represents
+ (x, y, visibility)
+ output_dir: Path to save visualization outputs.
+ Defaults to "predictions_visualizations"
+ num_samples: Number of images to visualize. If None, processes all images
+ random_select: If True, randomly samples images; if False, uses first N images
+ show_ground_truth: If True, displays ground truth poses alongside predictions.
+ If False, only shows predictions but uses GT visibility mask
+ plot_bboxes: If True and the model is a top-down model, predicted bboxes will
+ be shown in the images as well
+ """
+ # Setup output directory
+ output_dir = Path(output_dir or "predictions_visualizations")
+ output_dir.mkdir(exist_ok=True)
+
+ # Select images to process
+ image_paths = list(predictions.keys())
+ if num_samples and num_samples < len(image_paths):
+ if random_select:
+ image_paths = np.random.choice(
+ image_paths, num_samples, replace=False
+ ).tolist()
+ else:
+ image_paths = image_paths[:num_samples]
+
+ # Process each selected image
+ for image_path in image_paths:
+ # Get prediction and ground truth data
+ pred_data = predictions[image_path]
+ gt_keypoints = ground_truth[image_path] # Shape: [N, num_keypoints, 3]
+
+ # Create visibility mask from first GT sample. This mask will be applied to all samples for consistency
+ vis_mask = gt_keypoints[0, :, 2] > 0
+
+ # Process ground truth keypoints if showing GT
+ if show_ground_truth:
+ visible_gt = []
+ for gt in gt_keypoints:
+ visible_points = gt[vis_mask, :2] # Keep only x,y for visible joints
+ visible_gt.append(visible_points)
+ visible_gt = np.stack(visible_gt) # Shape: [N, num_visible_joints, 2]
+ else:
+ visible_gt = None
+
+ # Process predicted keypoints
+ pred_keypoints = pred_data["bodyparts"] # Shape: [N, num_keypoints, 3]
+ visible_pred = []
+ for pred in pred_keypoints:
+ visible_points = pred[vis_mask] # Keep only visible joint predictions
+ visible_pred.append(visible_points)
+ visible_pred = np.stack(visible_pred) # Shape: [N, num_visible_joints, 3]
+
+ if plot_bboxes:
+ bboxes = predictions[image_path].get("bboxes", None)
+ bbox_scores = predictions[image_path].get("bbox_scores", None)
+ bounding_boxes = (
+ (bboxes, bbox_scores)
+ if bboxes is not None and bbox_scores is not None
+ else None
+ )
+ else:
+ bounding_boxes = None
+
+ # Generate and save visualization
+ try:
+ plot_gt_and_predictions(
+ image_path=image_path,
+ output_dir=output_dir,
+ gt_bodyparts=visible_gt,
+ pred_bodyparts=visible_pred,
+ bounding_boxes=bounding_boxes,
+ )
+ print(f"Successfully plotted predictions for {image_path}")
+ except Exception as e:
+ print(f"Error plotting predictions for {image_path}: {str(e)}")
+
+
+def plot_gt_and_predictions(
+ image_path: str | Path,
+ output_dir: str | Path,
+ gt_bodyparts: np.ndarray,
+ pred_bodyparts: np.ndarray,
+ gt_unique_bodyparts: np.ndarray | None = None,
+ pred_unique_bodyparts: np.ndarray | None = None,
+ mode: str = "bodypart",
+ colormap: str = "rainbow",
+ dot_size: int = 12,
+ alpha_value: float = 0.7,
+ p_cutoff: float | list[float] = 0.6,
+ bounding_boxes: tuple[np.ndarray, np.ndarray] | None = None,
+ bboxes_pcutoff: float = 0.6,
+ bounding_boxes_color: str = "auto",
+):
+ """Plot ground truth and predictions on an image.
+
+ Args:
+ image_path: Path to the image
+ gt_bodyparts: Ground truth keypoints array (num_animals, num_keypoints, 3)
+ pred_bodyparts: Predicted keypoints array (num_animals, num_keypoints, 3)
+ output_dir: Directory where labeled images will be saved
+ gt_unique_bodyparts: Ground truth unique bodyparts if any
+ pred_unique_bodyparts: Predicted unique bodyparts if any
+ mode: How to color the points ("bodypart" or "individual")
+ colormap: Matplotlib colormap name
+ dot_size: Size of the plotted points
+ alpha_value: Transparency of the points
+ p_cutoff: Confidence threshold for showing predictions. If a list is provided,
+ there should be one value for each bodypart and one value for each unique
+ bodypart (if there are any).
+ bounding_boxes: bounding boxes (top-left corner, size) and their respective
+ confidence levels,
+ bboxes_pcutoff: bounding boxes confidence cutoff threshold.
+ bounding_boxes_color: If plotting bounding boxes, this is the color that will be
+ used for bounding boxes. If set to "auto" (default value):
+ - if mode is "bodypart", the bbox color will be a default color
+ - if mode is "individual", each individual's color will be used for its
+ bounding box
+ """
+ # Ensure output directory exists
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Read the image
+ frame = auxfun_videos.imread(str(image_path), mode="skimage")
+ num_pred, num_keypoints = pred_bodyparts.shape[:2]
+
+ # Create figure and set dimensions
+ fig, ax = create_minimal_figure()
+ h, w, _ = np.shape(frame)
+ fig.set_size_inches(w / 100, h / 100)
+ ax.set_xlim(0, w)
+ ax.set_ylim(0, h)
+ ax.invert_yaxis()
+ ax.imshow(frame, "gray")
+
+ # Set up colors based on mode
+ if mode == "bodypart":
+ num_colors = num_keypoints
+ if pred_unique_bodyparts is not None:
+ num_colors += pred_unique_bodyparts.shape[1]
+ colors = get_cmap(num_colors, name=colormap)
+
+ predictions = pred_bodyparts.swapaxes(0, 1)
+ ground_truth = gt_bodyparts.swapaxes(0, 1)
+ elif mode == "individual":
+ colors = get_cmap(num_pred + 1, name=colormap)
+ predictions = pred_bodyparts
+ ground_truth = gt_bodyparts
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+
+ if bounding_boxes_color == "auto":
+ if mode == "bodypart":
+ bboxes_color = None
+ elif mode == "individual":
+ bboxes_color = get_cmap(num_pred + 1, name=colormap)
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+ else:
+ bboxes_color = bounding_boxes_color
+
+ # Plot regular bodyparts
+ ax = make_multianimal_labeled_image(
+ frame,
+ ground_truth,
+ predictions[:, :, :2],
+ predictions[:, :, 2:],
+ colors,
+ dot_size,
+ alpha_value,
+ p_cutoff,
+ ax=ax,
+ bounding_boxes=bounding_boxes,
+ bboxes_cutoff=bboxes_pcutoff,
+ bboxes_color=bboxes_color,
+ )
+
+ # Plot unique bodyparts if present
+ if pred_unique_bodyparts is not None and gt_unique_bodyparts is not None:
+ if mode == "bodypart":
+ unique_predictions = pred_unique_bodyparts.swapaxes(0, 1)
+ unique_ground_truth = gt_unique_bodyparts.swapaxes(0, 1)
+ else:
+ unique_predictions = pred_unique_bodyparts
+ unique_ground_truth = gt_unique_bodyparts
+
+ ax = make_multianimal_labeled_image(
+ frame,
+ unique_ground_truth,
+ unique_predictions[:, :, :2],
+ unique_predictions[:, :, 2:],
+ colors[num_keypoints:],
+ dot_size,
+ alpha_value,
+ p_cutoff,
+ ax=ax,
+ )
+
+ # Save the labeled image
+ save_labeled_frame(
+ fig,
+ str(image_path),
+ str(output_dir),
+ belongs_to_train=False,
+ )
+ erase_artists(ax)
+ plt.close()
+
+
+def evaluate_snapshot(
+ cfg: dict,
+ loader: DLCLoader,
+ snapshot: Snapshot,
+ scorer: str,
+ transform: A.Compose | None = None,
+ plotting: bool | str = False,
+ show_errors: bool = True,
+ comparison_bodyparts: str | list[str] | None = None,
+ per_keypoint_evaluation: bool = False,
+ detector_snapshot: Snapshot | None = None,
+ pcutoff: float | list[float] | dict[str, float] | None = None,
+) -> pd.DataFrame:
+ """Evaluates a snapshot.
+ The evaluation results are stored in the .h5 and .csv file under the subdirectory
+ 'evaluation_results'.
+
+ Args:
+ cfg: the content of the project's config file
+ loader: the loader for the shuffle to evaluate
+ snapshot: the snapshot to evaluate
+ scorer: the scorer name to use for the snapshot
+ transform: transformation pipeline for evaluation
+ ** Should normalise the data the same way it was normalised during training **
+ plotting: Plots the predictions on the train and test images. If provided it must
+ be either ``True``, ``False``, ``"bodypart"``, or ``"individual"``. Setting
+ to ``True`` defaults as ``"bodypart"`` for multi-animal projects.
+ show_errors: whether to compare predictions and ground truth
+ comparison_bodyparts: A subset of the bodyparts for which to compute the
+ evaluation metrics.
+ per_keypoint_evaluation: Compute the train and test RMSE for each keypoint, and
+ save the results to a {model_name}-keypoint-results.csv in the
+ evaluation-results-pytorch folder.
+ detector_snapshot: Only for TD models. If defined, evaluation metrics are
+ computed using the detections made by this snapshot
+ pcutoff: The cutoff to use for computing evaluation metrics. When `None`, the
+ cutoff will be loaded from the project config. If a list is provided, there
+ should be one value for each bodypart and one value for each unique bodypart
+ (if there are any). If a dict is provided, the keys should be bodyparts
+ mapping to pcutoff values for each bodypart. Bodyparts that are not defined
+ in the dict will have pcutoff set to 0.6.
+ """
+ head_type = loader.model_cfg["model"]["heads"]["bodypart"]["type"]
+ if head_type == "DLCRNetHead":
+ prune_paf_graph.benchmark_paf_graphs(
+ loader=loader, snapshot_path=snapshot.path, verbose=False,
+ )
+
+ parameters = loader.get_dataset_parameters()
+
+ detector_path = None
+ if detector_snapshot is not None:
+ detector_path = detector_snapshot.path
+
+ pose_runner, detector_runner = get_inference_runners(
+ model_config=loader.model_cfg,
+ snapshot_path=snapshot.path,
+ max_individuals=parameters.max_num_animals,
+ num_bodyparts=parameters.num_joints,
+ num_unique_bodyparts=parameters.num_unique_bpts,
+ with_identity=loader.model_cfg["metadata"]["with_identity"],
+ transform=transform,
+ detector_path=detector_path,
+ )
+
+ # For memory-replay SuperAnimal models, convert bodyparts to project bodyparts
+ if weight_init_cfg := loader.model_cfg["train_settings"].get("weight_init", None):
+ weight_init = WeightInitialization.from_dict(weight_init_cfg)
+ if weight_init.memory_replay:
+ bodyparts = weight_init.bodyparts
+ if bodyparts is None:
+ bodyparts = auxiliaryfunctions.get_bodyparts(cfg)
+
+ parameters = PoseDatasetParameters(
+ bodyparts=bodyparts,
+ unique_bpts=parameters.unique_bpts,
+ individuals=parameters.individuals,
+ )
+
+ # get the names of bodyparts on which the model is evaluated
+ eval_parameters = PoseDatasetParameters(
+ bodyparts=_get_subset_bodyparts(parameters.bodyparts, comparison_bodyparts),
+ unique_bpts=_get_subset_bodyparts(parameters.unique_bpts, comparison_bodyparts),
+ individuals=parameters.individuals,
+ )
+
+ if pcutoff is None:
+ pcutoff = cfg.get("pcutoff", 0.6)
+ elif isinstance(pcutoff, dict):
+ pcutoff = [
+ pcutoff.get(bpt, 0.6)
+ for bpt in eval_parameters.bodyparts + eval_parameters.unique_bpts
+ ]
+ _validate_pcutoff(parameters.bodyparts, parameters.unique_bpts, pcutoff)
+
+ predictions = {}
+ rmse_per_bodypart = {}
+ bounding_boxes = {}
+ scores = {
+ "%Training dataset": loader.train_fraction,
+ "Shuffle number": loader.shuffle,
+ "Training epochs": snapshot.epochs,
+ "Detector epochs (TD only)": (
+ -1 if detector_snapshot is None else detector_snapshot.epochs
+ ),
+ "pcutoff": (
+ ", ".join([str(v) for v in pcutoff])
+ if isinstance(pcutoff, list) else pcutoff
+ ),
+ }
+ for split in ["train", "test"]:
+ results, predictions_for_split = evaluate(
+ pose_runner=pose_runner,
+ loader=loader,
+ mode=split,
+ pcutoff=pcutoff,
+ detector_runner=detector_runner,
+ comparison_bodyparts=comparison_bodyparts,
+ per_keypoint_evaluation=per_keypoint_evaluation,
+ parameters=parameters,
+ )
+ if per_keypoint_evaluation:
+ rmse_per_bodypart[split] = _extract_rmse_per_bodypart(
+ results,
+ eval_parameters.bodyparts,
+ eval_parameters.unique_bpts,
+ )
+
+ df_split_predictions = build_predictions_dataframe(
+ scorer=scorer,
+ predictions=predictions_for_split,
+ parameters=eval_parameters,
+ image_name_to_index=image_to_dlc_df_index,
+ )
+ split_bounding_boxes = build_bboxes_dict_for_dataframe(
+ predictions=predictions_for_split,
+ image_name_to_index=image_to_dlc_df_index,
+ )
+ predictions[split] = df_split_predictions
+ bounding_boxes[split] = split_bounding_boxes
+ for k, v in results.items():
+ scores[f"{split} {k}"] = round(v, 2)
+
+ results_filename = f"{scorer}.h5"
+ df_predictions = pd.concat(predictions.values(), axis=0)
+ df_predictions = df_predictions.reindex(loader.df.index)
+ output_filename = loader.evaluation_folder / results_filename
+ output_filename.parent.mkdir(parents=True, exist_ok=True)
+ df_predictions.to_hdf(output_filename, key="df_with_missing")
+
+ df_scores = pd.DataFrame([scores]).set_index(
+ [
+ "%Training dataset",
+ "Shuffle number",
+ "Training epochs",
+ "Detector epochs (TD only)",
+ "pcutoff",
+ ]
+ )
+ scores_filepath = output_filename.with_suffix(".csv")
+ scores_filepath = scores_filepath.with_stem(scores_filepath.stem + "-results")
+ save_evaluation_results(df_scores, scores_filepath, show_errors, pcutoff)
+
+ if per_keypoint_evaluation:
+ rmse_per_bpt_path = output_filename.with_name(
+ output_filename.stem + "-keypoint-results.csv"
+ )
+ save_rmse_per_bodypart(rmse_per_bodypart, rmse_per_bpt_path, show_errors)
+
+ if plotting:
+ folder_name = f"LabeledImages_{scorer}"
+ folder_path = loader.evaluation_folder / folder_name
+ folder_path.mkdir(parents=True, exist_ok=True)
+ if isinstance(plotting, str):
+ plot_mode = plotting
+ else:
+ plot_mode = "bodypart"
+
+ df_ground_truth = ensure_multianimal_df_format(loader.df)
+
+ bboxes_cutoff = (
+ loader.model_cfg.get("detector", {})
+ .get("model", {})
+ .get("box_score_thresh", 0.6)
+ )
+
+ for mode in ["train", "test"]:
+ df_combined = predictions[mode].merge(
+ df_ground_truth, left_index=True, right_index=True
+ )
+ bboxes_split = bounding_boxes[mode]
+
+ plot_evaluation_results(
+ df_combined=df_combined,
+ project_root=cfg["project_path"],
+ scorer=cfg["scorer"],
+ model_name=scorer,
+ output_folder=str(folder_path),
+ in_train_set=mode == "train",
+ plot_unique_bodyparts=eval_parameters.num_unique_bpts > 0,
+ mode=plot_mode,
+ colormap=cfg["colormap"],
+ dot_size=cfg["dotsize"],
+ alpha_value=cfg["alphavalue"],
+ p_cutoff=cfg["pcutoff"],
+ bounding_boxes=bboxes_split,
+ bboxes_cutoff=bboxes_cutoff,
+ )
+
+ return df_predictions
+
+
+def evaluate_network(
+ config: str | Path,
+ shuffles: Iterable[int] = (1,),
+ trainingsetindex: int | str = 0,
+ snapshotindex: int | str | None = None,
+ device: str | None = None,
+ plotting: bool | str = False,
+ show_errors: bool = True,
+ transform: A.Compose = None,
+ comparison_bodyparts: str | list[str] | None = None,
+ per_keypoint_evaluation: bool = False,
+ modelprefix: str = "",
+ detector_snapshot_index: int | None = None,
+ pcutoff: float | list[float] | dict[str, float] | None = None,
+) -> None:
+ """Evaluates a snapshot.
+
+ The evaluation results are stored in the .h5 and .csv file under the subdirectory
+ 'evaluation_results'.
+
+ Args:
+ config: path to the project's config file
+ shuffles: Iterable of integers specifying the shuffle indices to evaluate.
+ trainingsetindex: Integer specifying which training set fraction to use.
+ Evaluates all fractions if set to "all"
+ snapshotindex: index (starting at 0) of the snapshot we want to load. To
+ evaluate the last one, use -1. To evaluate all snapshots, use "all". For
+ example if we have 3 models saved
+ - snapshot-0.pt
+ - snapshot-50.pt
+ - snapshot-100.pt
+ and we want to evaluate snapshot-50.pt, snapshotindex should be 1. If None,
+ the snapshotindex is loaded from the project configuration.
+ device: the device to run evaluation on
+ plotting: Plots the predictions on the train and test images. If provided it must
+ be either ``True``, ``False``, ``"bodypart"``, or ``"individual"``. Setting
+ to ``True`` defaults as ``"bodypart"`` for multi-animal projects.
+ show_errors: display train and test errors.
+ transform: transformation pipeline for evaluation
+ ** Should normalise the data the same way it was normalised during training **
+ comparison_bodyparts: A subset of the bodyparts for which to compute the
+ evaluation metrics.
+ per_keypoint_evaluation: Compute the train and test RMSE for each keypoint, and
+ save the results to a {model_name}-keypoint-results.csv in the
+ evaluation-results-pytorch folder.
+ modelprefix: directory containing the deeplabcut models to use when evaluating
+ the network. By default, they are assumed to exist in the project folder.
+ detector_snapshot_index: Only for TD models. If defined, uses the detector with
+ the given index for pose estimation.
+ pcutoff: The cutoff to use for computing evaluation metrics. When `None`, the
+ cutoff will be loaded from the project config. If a list is provided, there
+ should be one value for each bodypart and one value for each unique bodypart
+ (if there are any). If a dict is provided, the keys should be bodyparts
+ mapping to pcutoff values for each bodypart. Bodyparts that are not defined
+ in the dict will have pcutoff set to 0.6.
+
+ Examples:
+ If you want to evaluate on shuffle 1 without plotting predictions.
+
+ >>> import deeplabcut
+ >>> deeplabcut.evaluate_network(
+ >>> '/analysis/project/reaching-task/config.yaml', shuffles=[1],
+ >>> )
+
+ If you want to evaluate shuffles 0 and 1 and plot the predictions.
+
+ >>> deeplabcut.evaluate_network(
+ >>> '/analysis/project/reaching-task/config.yaml',
+ >>> shuffles=[0, 1],
+ >>> plotting=True,
+ >>> )
+
+ If you want to plot assemblies for a maDLC project
+
+ >>> deeplabcut.evaluate_network(
+ >>> '/analysis/project/reaching-task/config.yaml',
+ >>> shuffles=[1],
+ >>> plotting="individual",
+ >>> )
+ """
+ cfg = auxiliaryfunctions.read_config(config)
+
+ if isinstance(trainingsetindex, int):
+ train_set_indices = [trainingsetindex]
+ elif isinstance(trainingsetindex, str) and trainingsetindex.lower() == "all":
+ train_set_indices = list(range(len(cfg["TrainingFraction"])))
+ else:
+ raise ValueError(f"Invalid trainingsetindex: {trainingsetindex}")
+
+ if snapshotindex is None:
+ snapshotindex = cfg["snapshotindex"]
+
+ if detector_snapshot_index is None:
+ detector_snapshot_index = cfg["detector_snapshotindex"]
+
+ for train_set_index in train_set_indices:
+ for shuffle in shuffles:
+ loader = DLCLoader(
+ config=config,
+ shuffle=shuffle,
+ trainset_index=train_set_index,
+ modelprefix=modelprefix,
+ )
+ loader.evaluation_folder.mkdir(exist_ok=True, parents=True)
+
+ if device is not None:
+ loader.model_cfg["device"] = device
+ loader.model_cfg["device"] = utils.resolve_device(loader.model_cfg)
+
+ snapshots = get_model_snapshots(
+ snapshotindex,
+ model_folder=loader.model_folder,
+ task=loader.pose_task,
+ )
+
+ detector_snapshots = [None]
+ if loader.pose_task == Task.TOP_DOWN:
+ if detector_snapshot_index is not None:
+ det_snapshots = get_model_snapshots(
+ "all", loader.model_folder, Task.DETECT
+ )
+ if len(det_snapshots) == 0:
+ print(
+ "The detector_snapshot_index was set to "
+ f"{detector_snapshot_index} but no detector snapshots were "
+ f"found in {loader.model_folder}. Using ground truth "
+ "bounding boxes to compute metrics.\n"
+ "To analyze videos with a top-down model, you'll need to "
+ "train a detector!"
+ )
+ else:
+ detector_snapshots = get_model_snapshots(
+ detector_snapshot_index,
+ loader.model_folder,
+ Task.DETECT,
+ )
+ else:
+ print("Using GT bounding boxes to compute evaluation metrics")
+
+ for detector_snapshot in detector_snapshots:
+ for snapshot in snapshots:
+ scorer = get_scorer_name(
+ cfg=cfg,
+ shuffle=shuffle,
+ train_fraction=loader.train_fraction,
+ snapshot_uid=get_scorer_uid(snapshot, detector_snapshot),
+ modelprefix=modelprefix,
+ )
+ evaluate_snapshot(
+ loader=loader,
+ cfg=cfg,
+ scorer=scorer,
+ snapshot=snapshot,
+ transform=transform,
+ plotting=plotting,
+ show_errors=show_errors,
+ comparison_bodyparts=comparison_bodyparts,
+ per_keypoint_evaluation=per_keypoint_evaluation,
+ detector_snapshot=detector_snapshot,
+ pcutoff=pcutoff,
+ )
+
+
+def image_to_dlc_df_index(image: str) -> tuple[str, ...]:
+ """
+ Args:
+ image: the path of the image to map to a DLC index
+
+ Returns:
+ the image index to create a multi-animal DLC dataframe:
+ ("labeled-data", video_name, image_name)
+ """
+ image_path = Path(image)
+ if len(image_path.parts) >= 3 and image_path.parts[-3] == "labeled-data":
+ return Path(image_path).parts[-3:]
+
+ raise ValueError(f"Unexpected image filepath for a DLC project")
+
+
+def save_evaluation_results(
+ df_scores: pd.DataFrame, scores_path: Path, print_results: bool, pcutoff: float
+) -> None:
+ """
+ Saves the evaluation results to a CSV file. Adds the evaluation results for the
+ model to the combined results file, or creates it if it does not yet exist.
+
+ Args:
+ df_scores: the scores dataframe for a snapshot
+ scores_path: the path where the model scores CSV should be saved
+ print_results: whether to print evaluation results to the console
+ pcutoff: the pcutoff used to get the evaluation results
+ """
+ if print_results:
+ print(f"Evaluation results for {scores_path.name} (pcutoff: {pcutoff}):")
+ print(df_scores.iloc[0])
+
+ # Save scores file
+ df_scores.to_csv(scores_path)
+
+ # Update combined results
+ combined_scores_path = scores_path.parent.parent / "CombinedEvaluation-results.csv"
+ if combined_scores_path.exists():
+ df_existing_results = pd.read_csv(
+ combined_scores_path, index_col=[0, 1, 2, 3, 4]
+ )
+ df_scores = df_scores.combine_first(df_existing_results)
+
+ df_scores = df_scores.sort_index()
+ df_scores.to_csv(combined_scores_path)
+
+
+def save_rmse_per_bodypart(
+ rmse_per_bodypart: dict[str, dict[str, float]],
+ output_path: Path,
+ print_results: bool,
+) -> None:
+ """
+ Saves the evaluation results per bodypart to a CSV file.
+
+ Args:
+ rmse_per_bodypart: The scores dataframe for a snapshot
+ output_path: The path of the file where
+ print_results: Whether to print results to the console
+ """
+ index, data = [], []
+ if print_results:
+ print(f"Per-bodypart evaluation results ({output_path.stem}):")
+
+ for split, rmse_results in rmse_per_bodypart.items():
+ key = split.capitalize() + " error (px)"
+ index.append(key)
+ data.append(rmse_results)
+
+ if print_results:
+ print(f" {key}")
+ bpt_key_length = max([len(k) for k in rmse_results.keys()]) + 4
+ for k, v in rmse_results.items():
+ key = (k + ":").ljust(bpt_key_length)
+ print(f" {key}{v:3>.2f}px")
+
+ # Save scores file
+ df_rmse_per_bodypart = pd.DataFrame(data, index=index)
+ df_rmse_per_bodypart.to_csv(output_path)
+
+
+def _validate_pcutoff(
+ bodyparts: list[str],
+ unique_bpts: list[str],
+ pcutoff: float | list[float],
+) -> None:
+ """Checks that the given `pcutoff` value has the correct number of elements"""
+ if isinstance(pcutoff, (int, float)):
+ return
+
+ total_bodyparts = len(bodyparts) + len(unique_bpts)
+ if len(pcutoff) != total_bodyparts:
+ raise ValueError(
+ "When passing the pcutoff as a list, the length of the list should be "
+ "equal to the number of bodyparts and the number of unique bpts. "
+ f"Found a list containing {len(pcutoff)} elements, but there are "
+ f"{total_bodyparts} total bodyparts, which are {bodyparts + unique_bpts}."
+ )
+
+
+def _get_keypoints_to_use(
+ bodyparts: list[str],
+ bodypart_subset: str | list[str] | None,
+) -> list[int] | None:
+ """Computes the indices of the keypoints indices to keep based on the given subset.
+
+ Args:
+ bodyparts: The bodyparts predicted by the model.
+ bodypart_subset: The subset of bodyparts to keep. If None or "all", all
+ bodyparts are kept.
+
+ Returns:
+ None if all bodyparts should be kept, or bodyparts is an empty list. Otherwise,
+ returns a list containing the indices of the bodyparts to keep. If no bodyparts
+ should be kept, returns an empty list.
+ """
+ if len(bodyparts) == 0 or bodypart_subset is None or bodypart_subset == "all":
+ return None
+
+ if isinstance(bodypart_subset, str):
+ bodypart_subset = [bodypart_subset]
+
+ to_keep = set(bodypart_subset)
+ return [i for i, b in enumerate(bodyparts) if b in to_keep]
+
+
+def _get_subset_bodyparts(
+ bodyparts: list[str],
+ subset: str | list[str] | None,
+) -> list[str]:
+ """Gets a subset of bodyparts that were used.
+
+ Args:
+ bodyparts: The bodyparts output by the model.
+ subset: The subset of bodyparts to keep.
+
+ Returns:
+ The bodyparts that were used to evaluate the model.
+ """
+ if subset is None or subset == "all":
+ return bodyparts
+
+ if isinstance(subset, str):
+ subset = [subset]
+
+ to_keep = set(subset)
+ return [b for b in bodyparts if b in to_keep]
+
+
+def _extract_rmse_per_bodypart(
+ results: dict[str, float],
+ bodyparts: list[str],
+ unique_bodyparts: list[str],
+) -> dict[str, float]:
+ """Extracts the RMSE per bodypart metrics from the results dict
+
+ This method modifies the given dict in-place, removing all keys for RMSE per
+ bodypart or unique bodypart.
+
+ Args:
+ results: The results returned by the evaluation method.
+ bodyparts: The bodyparts defined for the project.
+ unique_bodyparts: The unique bodyparts defined for the project.
+
+ Returns:
+ The per-bodypart RMSE.
+ """
+ rmse_per_bodypart = {}
+ for bpt_idx, bpt in enumerate(bodyparts):
+ rmse = results.pop(f"rmse_keypoint_{bpt_idx}", None)
+ if rmse is not None:
+ rmse_per_bodypart[bpt] = rmse
+
+ for bpt_idx, bpt in enumerate(unique_bodyparts):
+ rmse = results.pop(f"rmse_unique_keypoint_{bpt_idx}", None)
+ if rmse is not None:
+ rmse_per_bodypart[bpt] = rmse
+
+ return rmse_per_bodypart
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str)
+ parser.add_argument("--modelprefix", type=str, default="")
+ parser.add_argument("--snapshotindex", type=int, default=49)
+ parser.add_argument("--plotting", type=bool, default=False)
+ parser.add_argument("--show_errors", type=bool, default=True)
+ args = parser.parse_args()
+ evaluate_network(
+ config=args.config,
+ modelprefix=args.modelprefix,
+ snapshotindex=args.snapshotindex,
+ plotting=args.plotting,
+ show_errors=args.show_errors,
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/apis/export.py b/deeplabcut/pose_estimation_pytorch/apis/export.py
new file mode 100644
index 0000000000..09de76f369
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/export.py
@@ -0,0 +1,194 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Code to export DeepLabCut models for DLCLive inference"""
+import copy
+from pathlib import Path
+
+import torch
+
+import deeplabcut.pose_estimation_pytorch.apis.utils as utils
+import deeplabcut.pose_estimation_pytorch.data as dlc3_data
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+def export_model(
+ config: str | Path,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ snapshotindex: int | None = None,
+ detector_snapshot_index: int | None = None,
+ iteration: int | None = None,
+ overwrite: bool = False,
+ wipe_paths: bool = False,
+ without_detector: bool = False,
+ modelprefix: str | None = None,
+) -> None:
+ """Export DeepLabCut models for live inference.
+
+ Saves the pytorch_config.yaml configuration, snapshot files, of the model to a
+ directory named exported-models-pytorch within the project directory.
+
+ Args:
+ config: Path of the project configuration file
+ shuffle : The shuffle of the model to export.
+ trainingsetindex: The index of the training fraction for the model you wish to
+ export.
+ snapshotindex: The snapshot index for the weights you wish to export. If None,
+ uses the snapshotindex as defined in ``config.yaml``.
+ detector_snapshot_index: Only for TD models. If defined, uses the detector with
+ the given index for pose estimation. If None, uses the snapshotindex as
+ defined in the project ``config.yaml``.
+ iteration: The project iteration (active learning loop) you wish to export. If
+ None, the iteration listed in the project config file is used.
+ overwrite : bool, optional
+ If the model you wish to export has already been exported, whether to
+ overwrite. default = False
+ wipe_paths : bool, optional
+ Removes the actual path of your project and the init_weights from the
+ ``pytorch_config.yaml``.
+ without_detector: bool, optional
+ Exports top-down models without the detector.
+ modelprefix: Directory containing the deeplabcut models to use when evaluating
+ the network. By default, the models are assumed to exist in the project
+ folder.
+
+ Raises:
+ ValueError: If no snapshots could be found for the shuffle.
+ ValueError: If a top-down model is exported but no detector snapshots are found.
+
+ Examples:
+ Export the last stored snapshot for model trained with shuffle 3:
+ >>> import deeplabcut
+ >>> deeplabcut.export_model(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> shuffle=3,
+ >>> snapshotindex=-1,
+ >>> )
+ """
+ cfg = af.read_config(str(config))
+ if iteration is not None:
+ cfg["iteration"] = iteration
+
+ loader = dlc3_data.DLCLoader(
+ config=cfg,
+ trainset_index=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix="" if modelprefix is None else modelprefix,
+ )
+
+ if snapshotindex is None:
+ snapshotindex = loader.project_cfg["snapshotindex"]
+ snapshots = utils.get_model_snapshots(
+ snapshotindex, loader.model_folder, loader.pose_task
+ )
+
+ if len(snapshots) == 0:
+ raise ValueError(
+ f"Could not find any snapshots to export in ``{loader.model_folder}`` for "
+ f"``snapshotindex={snapshotindex}``."
+ )
+
+ detector_snapshots = [None]
+ if loader.pose_task == Task.TOP_DOWN and not without_detector:
+ if detector_snapshot_index is None:
+ detector_snapshot_index = loader.project_cfg["detector_snapshotindex"]
+ detector_snapshots = utils.get_model_snapshots(
+ detector_snapshot_index, loader.model_folder, Task.DETECT
+ )
+
+ if len(detector_snapshots) == 0:
+ raise ValueError(
+ "Attempting to export a top-down pose estimation model but no detector "
+ f"snapshots were found in ``{loader.model_folder}`` for "
+ f"``detector_snapshot_index={detector_snapshot_index}``. You must "
+ f"export a detector snapshot with a top-down pose estimation model."
+ )
+
+ export_folder_name = get_export_folder_name(loader)
+ export_dir = loader.project_path / "exported-models-pytorch" / export_folder_name
+ export_dir.mkdir(exist_ok=True, parents=True)
+
+ load_kwargs = dict(map_location="cpu", weights_only=True)
+ for det_snapshot in detector_snapshots:
+ detector_weights = None
+ if det_snapshot is not None:
+ detector_weights = torch.load(det_snapshot.path, **load_kwargs)["model"]
+
+ for snapshot in snapshots:
+ export_filename = get_export_filename(loader, snapshot, det_snapshot)
+ export_path = export_dir / export_filename
+ if export_path.exists() and not overwrite:
+ continue
+
+ model_cfg = copy.deepcopy(loader.model_cfg)
+ if wipe_paths:
+ wipe_paths_from_model_config(model_cfg)
+
+ pose_weights = torch.load(snapshot.path, **load_kwargs)["model"]
+ export_dict = dict(config=model_cfg, pose=pose_weights)
+ if detector_weights is not None:
+ export_dict["detector"] = detector_weights
+
+ torch.save(export_dict, export_path)
+
+
+def get_export_folder_name(loader: dlc3_data.DLCLoader) -> str:
+ """
+ Args:
+ loader: The loader for the shuffle for which we want to export models.
+
+ Returns:
+ The name of the folder in which exported models should be placed for a shuffle.
+ """
+ return (
+ f"DLC_{loader.project_cfg['Task']}_{loader.model_cfg['net_type']}_"
+ f"iteration-{loader.project_cfg['iteration']}_shuffle-{loader.shuffle}"
+ )
+
+
+def get_export_filename(
+ loader: dlc3_data.DLCLoader,
+ snapshot: Snapshot,
+ detector_snapshot: Snapshot | None = None,
+) -> str:
+ """
+ Args:
+ loader: The loader for the shuffle for which we want to export models.
+ snapshot: The pose model snapshot to export.
+ detector_snapshot: The detector snapshot to export, for top-down models.
+
+ Returns:
+ The name of the file in which the exported model should be stored.
+ """
+ export_filename = get_export_folder_name(loader)
+ if detector_snapshot is not None:
+ export_filename += "_snapshot-detector" + detector_snapshot.uid()
+ export_filename += "_snapshot-" + snapshot.uid()
+ return export_filename + ".pt"
+
+
+def wipe_paths_from_model_config(model_cfg: dict) -> None:
+ """
+ Removes all paths from the contents of the ``pytorch_config`` file.
+
+ Args:
+ model_cfg: The model configuration to wipe.
+ """
+ model_cfg["metadata"]["project_path"] = ""
+ model_cfg["metadata"]["pose_config_path"] = ""
+ if "weight_init" in model_cfg["train_settings"]:
+ model_cfg["train_settings"]["weight_init"] = None
+ if "resume_training_from" in model_cfg:
+ model_cfg["resume_training_from"] = None
+ if "resume_training_from" in model_cfg.get("detector", {}):
+ model_cfg["detector"]["resume_training_from"] = None
diff --git a/deeplabcut/pose_estimation_pytorch/apis/prune_paf_graph.py b/deeplabcut/pose_estimation_pytorch/apis/prune_paf_graph.py
new file mode 100644
index 0000000000..a31a833982
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/prune_paf_graph.py
@@ -0,0 +1,294 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from collections import defaultdict
+from pathlib import Path
+
+import networkx as nx
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import deeplabcut.core.metrics as metrics
+import deeplabcut.pose_estimation_pytorch.apis.utils as utils
+import deeplabcut.pose_estimation_pytorch.data as data
+import deeplabcut.pose_estimation_pytorch.models.predictors as predictors
+import deeplabcut.utils.auxiliaryfunctions as auxiliaryfunctions
+from deeplabcut.core.crossvalutils import find_closest_neighbors
+from deeplabcut.pose_estimation_pytorch.models import PoseModel
+from deeplabcut.pose_estimation_pytorch.models.predictors.paf_predictor import Graph
+
+
+@torch.no_grad()
+def benchmark_paf_graphs(
+ loader: data.Loader,
+ snapshot_path: Path,
+ verbose: bool = False,
+ overwrite: bool = False,
+ update_config: bool = True,
+) -> list[dict]:
+ """Prunes the PAF graph to maximize performance
+
+ Args:
+ loader: The loader for the model to prune.
+ snapshot_path: The path to the snapshot with which to prune the model.
+ verbose: Verbose pruning of the model.
+ overwrite: Whether to overwrite the graph if it was already pruned.
+ update_config: Whether to update the model configuration with the pruned graph.
+
+ Returns:
+ A list of dictionaries containing results for each pruned graph.
+
+ If the graph was already pruned, a single element is returned with an
+ "edges_to_keep" key, containing the indices of edges to keep in the graph.
+
+ Otherwise, a list of graphs that were evaluated is returned, with "key_metric",
+ "edges_to_keep" and "metrics" keys. The list is sorted by "key_metric" (which
+ is pose mAP).
+ """
+ runner = utils.get_pose_inference_runner(loader.model_cfg, snapshot_path)
+ device = runner.device
+ preprocessor = runner.preprocessor
+ model = runner.model
+ predictor = model.heads.bodypart.predictor
+
+ # only benchmark the PAF graph if the PAF indices contain all edges
+ if not overwrite and len(predictor.edges_to_keep) < len(predictor.graph):
+ return [dict(edges_to_keep=predictor.edges_to_keep)]
+
+ model.to(device)
+ model.eval()
+
+ if not isinstance(predictor, predictors.PartAffinityFieldPredictor):
+ raise ValueError(f"Predictor should be a PartAffinityFieldPredictor.")
+
+ if verbose:
+ print("-------------------------------------------------")
+ print("Benchmarking different Part-Affinity Field Graphs")
+ print(" (1/3) Obtaining the best graph candidates")
+
+ gt_train = loader.ground_truth_keypoints("train")
+ best_paf_edges, _ = get_n_best_paf_graphs(
+ model,
+ gt_train,
+ preprocessor,
+ device,
+ predictor.graph,
+ n_graphs=10,
+ )
+
+ if verbose:
+ print(" (2/3) Running test inference")
+
+ gt_test = loader.ground_truth_keypoints("test")
+ images_test = [img_path for img_path in gt_test]
+
+ predictions = {graph_id: {} for graph_id in range(len(best_paf_edges))}
+ with torch.no_grad():
+ for image_path in tqdm(images_test):
+ image, _ = preprocessor(image_path, {})
+ outputs = model(image.to(device))
+ for graph_id, edges in enumerate(best_paf_edges):
+ predictor.set_paf_edges_to_keep(edges)
+ pred_pose = model.get_predictions(outputs)["bodypart"]["poses"]
+ predictions[graph_id][image_path] = pred_pose.cpu().numpy()[0]
+
+ if verbose:
+ print(" (3/3) Evaluating Graphs")
+
+ results = []
+ for graph_id, pred_pose in predictions.items():
+ edges_to_keep = [int(i) for i in best_paf_edges[graph_id]]
+ graph_metrics = metrics.compute_metrics(
+ gt_test,
+ pred_pose,
+ single_animal=False,
+ pcutoff=0.6,
+ )
+ results.append(
+ dict(
+ edges_to_keep=edges_to_keep,
+ key_metric=graph_metrics["mAP"],
+ metrics=graph_metrics,
+ )
+ )
+
+ if verbose:
+ print(" ---")
+ print(f" |Graph {graph_id}: {len(edges_to_keep)} edges")
+ print(f" | mAP: {graph_metrics['mAP']}")
+ print(f" | mAR: {graph_metrics['mAR']}")
+ print(f" | edges: {edges_to_keep}")
+ print()
+
+ results = list(sorted(results, key=lambda r: 1 - r["key_metric"]))
+
+ if update_config and len(results) > 0:
+ best_results = results[0]
+ best_edges = best_results["edges_to_keep"]
+ graph_metrics = best_results["metrics"]
+
+ if verbose:
+ print("Selecting the following Graph")
+ print(60 * "-")
+ print(f"|Graph with {len(best_edges)} edges")
+ print(f"| mAP: {graph_metrics['mAP']}")
+ print(f"| mAR: {graph_metrics['mAR']}")
+ print(f"| edges: {best_edges}")
+ print()
+
+ # update the edges to keep in the PyTorch configuration file
+ loader.update_model_cfg(
+ {"model.heads.bodypart.predictor.edges_to_keep": best_edges}
+ )
+
+ # update the edges indices
+ test_config = loader.model_folder.parent / "test" / "pose_cfg.yaml"
+ auxiliaryfunctions.edit_config(str(test_config), dict(paf_best=best_edges))
+
+ return results
+
+
+def _calc_separability(
+ vals_left: np.ndarray,
+ vals_right: np.ndarray,
+ n_bins: int = 101,
+ metric: str = "jeffries",
+ max_sensitivity: bool = False,
+) -> tuple[float, float]:
+ if metric not in ("jeffries", "auc"):
+ raise ValueError("`metric` should be either 'jeffries' or 'auc'.")
+
+ bins = np.linspace(0, 1, n_bins)
+ hist_left = np.histogram(vals_left, bins=bins)[0]
+ hist_left = hist_left / hist_left.sum()
+ hist_right = np.histogram(vals_right, bins=bins)[0]
+ hist_right = hist_right / hist_right.sum()
+ tpr = np.cumsum(hist_right)
+ if metric == "jeffries":
+ sep = np.sqrt(
+ 2 * (1 - np.sum(np.sqrt(hist_left * hist_right)))
+ ) # Jeffries-Matusita distance
+ else:
+ sep = np.trapz(np.cumsum(hist_left), tpr)
+ if max_sensitivity:
+ threshold = bins[max(1, np.argmax(tpr > 0))]
+ else:
+ threshold = bins[np.argmin(1 - np.cumsum(hist_left) + tpr)]
+ return sep, threshold
+
+
+@torch.no_grad()
+def compute_within_between_paf_costs(
+ model: PoseModel,
+ ground_truth: dict[str, np.ndarray],
+ preprocessor: data.Preprocessor,
+ device: str,
+) -> tuple[defaultdict, defaultdict]:
+ predictor = model.heads.bodypart.predictor
+ images = [img_path for img_path in ground_truth]
+
+ within = defaultdict(list)
+ between = defaultdict(list)
+ for image_path in tqdm(images):
+ image, _ = preprocessor(image_path, {})
+ outputs = model(image.to(device))
+ preds = model.get_predictions(outputs)["bodypart"]["preds"][0]
+ gt_pose_with_vis = ground_truth[image_path].transpose((1, 0, 2))
+
+ # mask non-visible keypoints
+ gt_pose = gt_pose_with_vis[..., :2].copy()
+ gt_pose[gt_pose_with_vis[..., 2] <= 0] = np.nan
+
+ if np.isnan(gt_pose).all():
+ continue
+
+ coords_pred = preds["coordinates"][0]
+ costs_pred = preds["costs"]
+
+ # Get animal IDs and corresponding indices in the arrays of detections
+ lookup = dict()
+ for i, (coord_pred, coord_gt) in enumerate(zip(coords_pred, gt_pose)):
+ inds = np.flatnonzero(np.all(~np.isnan(coord_pred), axis=1))
+ inds_gt = np.flatnonzero(np.all(~np.isnan(coord_gt), axis=1))
+ if inds.size and inds_gt.size:
+ neighbors = find_closest_neighbors(
+ coord_gt[inds_gt], coord_pred[inds], k=3
+ )
+ found = neighbors != -1
+ lookup[i] = dict(zip(inds_gt[found], inds[neighbors[found]]))
+
+ for k, v in costs_pred.items():
+ paf = v["m1"]
+ mask_within = np.zeros(paf.shape, dtype=bool)
+ s, t = predictor.graph[k]
+ if s not in lookup or t not in lookup:
+ continue
+ lu_s = lookup[s]
+ lu_t = lookup[t]
+ common_id = set(lu_s).intersection(lu_t)
+ for id_ in common_id:
+ mask_within[lu_s[id_], lu_t[id_]] = True
+ within_vals = paf[mask_within]
+ between_vals = paf[~mask_within]
+ within[k].extend(within_vals)
+ between[k].extend(between_vals)
+
+ return within, between
+
+
+def get_n_best_paf_graphs(
+ model: PoseModel,
+ ground_truth: dict[str, np.ndarray],
+ preprocessor: data.Preprocessor,
+ device: str,
+ full_graph: Graph,
+ root_edges: list[int] | None = None,
+ n_graphs: int = 10,
+ metric: str = "auc",
+) -> tuple[list[list[int]], dict[int, float]]:
+ return_preds = model.heads.bodypart.predictor.return_preds
+ model.heads.bodypart.predictor.return_preds = True
+
+ within_train, between_train = compute_within_between_paf_costs(
+ model, ground_truth, preprocessor, device
+ )
+ existing_edges = list(set(k for k, v in within_train.items() if v))
+
+ scores, _ = zip(
+ *[
+ _calc_separability(between_train[n], within_train[n], metric=metric)
+ for n in existing_edges
+ ]
+ )
+
+ # Find minimal skeleton
+ G = nx.Graph()
+ for edge, score in zip(existing_edges, scores):
+ if np.isfinite(score):
+ G.add_edge(*full_graph[edge], weight=score)
+
+ order = np.asarray(existing_edges)[np.argsort(scores)[::-1]]
+ if root_edges is None:
+ root_edges = []
+ for edge in nx.maximum_spanning_edges(G, data=False):
+ root_edges.append(full_graph.index(sorted(edge)))
+
+ n_edges = len(existing_edges) - len(root_edges)
+ lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:]
+ order = order[np.isin(order, root_edges, invert=True)]
+ best_edges = [root_edges]
+ for length in lengths:
+ best_edges.append(root_edges + list(order[:length]))
+
+ model.heads.bodypart.predictor.return_preds = return_preds
+ return best_edges, dict(zip(existing_edges, scores))
diff --git a/deeplabcut/pose_estimation_pytorch/apis/tracking_dataset.py b/deeplabcut/pose_estimation_pytorch/apis/tracking_dataset.py
new file mode 100644
index 0000000000..d2b5d35d2d
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/tracking_dataset.py
@@ -0,0 +1,278 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Code to create tracking datasets for ReID model training"""
+from pathlib import Path
+
+from tqdm import tqdm
+
+import deeplabcut.pose_estimation_pytorch.apis.utils as utils
+import deeplabcut.pose_estimation_pytorch.data as data
+import deeplabcut.pose_estimation_pytorch.data.postprocessor as postprocessing
+import deeplabcut.pose_estimation_pytorch.models as models
+import deeplabcut.pose_estimation_pytorch.runners as runners
+import deeplabcut.pose_estimation_pytorch.runners.shelving as shelving
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.pose_estimation_pytorch.apis.videos import VideoIterator
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.pose_tracking_pytorch import create_triplets_dataset
+
+
+def build_feature_extraction_runner(
+ loader: data.Loader,
+ snapshot_path: str | Path,
+ device: str,
+ batch_size: int = 1,
+) -> runners.PoseInferenceRunner:
+ """Builds a runner to extract backbone features for poses of individuals
+
+ Args:
+ loader: The loader for the model to use.
+ snapshot_path: The path of the snapshot to use.
+ device: The device on which to run pose estimation.
+ batch_size: The batch size to run pose estimation with.
+
+ Returns:
+ A PoseInferenceRunner that will return features for extracted pose.
+ """
+ num_features = loader.model_cfg["model"]["backbone_output_channels"]
+ num_bodyparts = len(loader.model_cfg["metadata"]["bodyparts"])
+ top_down = loader.pose_task != Task.BOTTOM_UP
+ rescale_mode = postprocessing.RescaleAndOffset.Mode.KEYPOINT
+ if top_down:
+ rescale_mode = postprocessing.RescaleAndOffset.Mode.KEYPOINT_TD
+ data_cfg = loader.model_cfg["data"]["inference"]
+ crop_cfg = data_cfg.get("top_down_crop", {})
+ width, height = crop_cfg.get("width", 256), crop_cfg.get("height", 256)
+ preprocessor = data.build_top_down_preprocessor(
+ color_mode=loader.model_cfg["data"]["colormode"],
+ transform=data.build_transforms(data_cfg),
+ top_down_crop_size=(width, height),
+ top_down_crop_margin=crop_cfg.get("margin", 0),
+ )
+ else:
+ preprocessor = data.build_bottom_up_preprocessor(
+ loader.model_cfg["data"]["colormode"],
+ data.build_transforms(loader.model_cfg["data"]["inference"])
+ )
+
+ postprocessor = postprocessing.ComposePostprocessor(
+ [
+ postprocessing.PrepareBackboneFeatures(top_down=top_down),
+ postprocessing.ConcatenateOutputs(
+ keys_to_concatenate={
+ "bodyparts": ("bodypart", "poses"),
+ "features": ("backbone", "bodypart_features"),
+ },
+ empty_shapes={
+ "bodyparts": (num_bodyparts, 3),
+ "features": (num_bodyparts, num_features),
+ },
+ create_empty_outputs=True,
+ ),
+ postprocessing.RescaleAndOffset(["bodyparts"], rescale_mode),
+ ]
+ )
+
+ runner = runners.build_inference_runner(
+ task=loader.pose_task,
+ model=models.PoseModel.build(loader.model_cfg["model"]),
+ device=device,
+ snapshot_path=snapshot_path,
+ batch_size=batch_size,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ load_weights_only=loader.model_cfg["runner"].get("load_weights_only", None),
+ )
+ assert isinstance(runner, runners.PoseInferenceRunner), (
+ f"Failed to build inference runner: got type {type(runner)}"
+ )
+
+ # Set the model to output backbone features
+ runner.model.output_features = True
+
+ return runner
+
+
+def extract_features_for_video(
+ runner: runners.PoseInferenceRunner,
+ video: VideoIterator,
+ shelf_writer: shelving.FeatureShelfWriter,
+ detector_runner: runners.DetectorInferenceRunner | None = None,
+) -> None:
+ """Extracts backbone features for predicted keypoints in a video.
+
+ Args:
+ video: The video for which to extract backbone features.
+ runner: The inference runner with which to extract backbone features.
+ shelf_writer: The ShelfWriter used to extract features.
+ detector_runner: For top-down models, the detector to use to predict bboxes.
+ """
+ if detector_runner is not None:
+ print(f"Running detector with batch size {detector_runner.batch_size}")
+ bbox_predictions = detector_runner.inference(images=tqdm(video))
+ video.set_context(bbox_predictions)
+
+ shelf_writer.open()
+ runner.inference(tqdm(video), shelf_writer=shelf_writer)
+ shelf_writer.close()
+
+
+def create_tracking_dataset(
+ config: str,
+ videos: list[str] | list[Path],
+ track_method: str,
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ destfolder: str | None = None,
+ batch_size: int | None = None,
+ detector_batch_size: int | None = None,
+ cropping: list[int] | None = None,
+ modelprefix: str = "",
+ robust_nframes: bool = False,
+ n_triplets: int = 1000,
+) -> str:
+ """Creates a tracking dataset to train a ReID tracklet stitcher.
+
+ Args:
+ config: Full path of the config.yaml file for the project
+ videos: A str (or list of strings) containing the full paths to videos from
+ which to create the tracking dataset or a path to the directory, where all
+ the videos with same extension are stored.
+ track_method: Specifies the tracker used to generate the pose estimation data.
+ Must be either 'box', 'skeleton', or 'ellipse'.
+ videotype: Checks for the extension of the video in case the input to the video
+ is a directory. Only videos with this extension are analyzed. If left
+ unspecified, keeps videos with extensions ('avi', 'mp4', 'mov', 'mpeg',
+ 'mkv').
+ shuffle: An integer specifying the shuffle index of the training dataset used
+ for training the network.
+ trainingsetindex: Integer specifying which TrainingsetFraction to use.
+ destfolder: Specifies the destination folder for the tracking data. If ``None``,
+ the path of the video is used. Note that for subsequent analysis this
+ folder also needs to be passed.
+ batch_size: The batch size to use for inference. Takes the value from the
+ project config as a default.
+ detector_batch_size: The batch size to use for detector inference. Takes the
+ value from the project config as a default.
+ cropping: List of cropping coordinates as [x1, x2, y1, y2]. Note that the same
+ cropping parameters will then be used for all videos. If different video
+ crops are desired, run ``analyze_videos`` on individual videos with the
+ corresponding cropping coordinates.
+ modelprefix: Directory containing the deeplabcut models to use when evaluating
+ the network. By default, they are assumed to exist in the project folder.
+ robust_nframes: Evaluate a video's number of frames in a robust manner. This
+ option is slower (as the whole video is read frame-by-frame), but does not
+ rely on metadata, hence its robustness against file corruption.
+ n_triplets: The number of triplets to extract for the dataset.
+
+ Returns:
+ The scorer used to analyze the videos.
+ """
+ loader = data.DLCLoader(
+ config,
+ trainset_index=trainingsetindex,
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+ test_cfg_path = loader.model_folder.parent / "test" / "pose_cfg.yaml"
+ test_cfg = read_config_as_dict(test_cfg_path)
+
+ snapshot_index, detector_snapshot_index = utils.parse_snapshot_index_for_analysis(
+ loader.project_cfg, loader.model_cfg, None, None,
+ )
+ snapshot = utils.get_model_snapshots(
+ snapshot_index, loader.model_folder, loader.pose_task,
+ )[0]
+
+ if cropping is None and loader.project_cfg.get("cropping", False):
+ cropping = (
+ loader.project_cfg["x1"],
+ loader.project_cfg["x2"],
+ loader.project_cfg["y1"],
+ loader.project_cfg["y2"],
+ )
+
+ output_folder = None
+ if destfolder is not None and destfolder != "":
+ output_folder = Path(destfolder)
+
+ if batch_size is None:
+ batch_size = loader.project_cfg["batch_size"]
+
+ device = utils.resolve_device(loader.model_cfg)
+ runner = build_feature_extraction_runner(
+ loader, snapshot.path, device, batch_size=batch_size
+ )
+
+ detector_runner = None
+ detector_snapshot = None
+ if loader.pose_task == Task.TOP_DOWN:
+ if detector_batch_size is None:
+ detector_batch_size = loader.project_cfg.get("detector_batch_size", 1)
+
+ detector_snapshot = utils.get_model_snapshots(
+ detector_snapshot_index, loader.model_folder, Task.DETECT,
+ )[0]
+ detector_runner = utils.get_detector_inference_runner(
+ model_config=loader.model_cfg,
+ snapshot_path=detector_snapshot.path,
+ batch_size=detector_batch_size,
+ device=device,
+ )
+
+ dlc_scorer = utils.get_scorer_name(
+ loader.project_cfg,
+ shuffle,
+ loader.train_fraction,
+ snapshot_uid=utils.get_scorer_uid(snapshot, detector_snapshot),
+ modelprefix=modelprefix,
+ )
+
+ videos = utils.list_videos_in_folder(videos, videotype)
+ for video_path in videos:
+ print(f"Loading {video_path}")
+ video = VideoIterator(video_path, cropping=cropping)
+
+ nx, ny = video.dimensions
+ nframes = video.get_n_frames(robust=robust_nframes)
+ duration = video.calc_duration(robust=robust_nframes)
+ fps = video.fps
+ if robust_nframes:
+ fps = nframes / duration
+
+ print(f"Duration of video [s]: {duration:.2f}, recorded with {fps:.2f} fps!")
+ print(f"Overall # of frames: {nframes} found with (before cropping)")
+ print(f"Frame dimensions: {nx} x {ny}")
+
+ if output_folder is None:
+ output_folder = Path(video.video_path).parent
+ output_folder.mkdir(parents=True, exist_ok=True)
+ output_prefix = Path(video_path).stem + dlc_scorer
+ output_filepath = output_folder / f"{output_prefix}_bpt_features.pickle"
+
+ shelf_writer = shelving.FeatureShelfWriter(
+ test_cfg,
+ output_filepath,
+ num_frames=video.get_n_frames(robust=robust_nframes),
+ )
+ extract_features_for_video(
+ runner, video, shelf_writer, detector_runner=detector_runner
+ )
+
+ create_triplets_dataset(
+ videos,
+ dlc_scorer,
+ track_method,
+ n_triplets=n_triplets,
+ destfolder=destfolder,
+ )
+ return dlc_scorer
diff --git a/deeplabcut/pose_estimation_pytorch/apis/tracklets.py b/deeplabcut/pose_estimation_pytorch/apis/tracklets.py
new file mode 100644
index 0000000000..f9ec37cb11
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/tracklets.py
@@ -0,0 +1,299 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+import pickle
+import warnings
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import pandas as pd
+from scipy.optimize import linear_sum_assignment
+from scipy.special import softmax
+from tqdm import tqdm
+
+import deeplabcut.utils.auxiliaryfunctions as auxiliaryfunctions
+import deeplabcut.utils.auxfun_multianimal as auxfun_multianimal
+from deeplabcut.core import trackingutils
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.inferenceutils import Assembly
+from deeplabcut.pose_estimation_pytorch.apis.utils import (
+ get_scorer_name,
+ list_videos_in_folder,
+)
+
+
+def convert_detections2tracklets(
+ config: str,
+ videos: Union[str, List[str]],
+ videotype: Optional[str] = None,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ overwrite: bool = False,
+ destfolder: Optional[str] = None,
+ ignore_bodyparts: Optional[List[str]] = None,
+ inferencecfg: Optional[dict] = None,
+ modelprefix="",
+ greedy: bool = False, # TODO(niels): implement greedy assembly during video analysis
+ calibrate: bool = False, # TODO(niels): implement assembly calibration during video analysis
+ window_size: int = 0, # TODO(niels): implement window size selection for assembly during video analysis
+ identity_only=False,
+ track_method="",
+):
+ """TODO: Documentation, clean & remove code duplication (with analyze video)"""
+ cfg = auxiliaryfunctions.read_config(config)
+ inference_cfg = inferencecfg
+ track_method = auxfun_multianimal.get_track_method(cfg, track_method=track_method)
+
+ if len(cfg["multianimalbodyparts"]) == 1 and track_method != "box":
+ warnings.warn("Switching to `box` tracker for single point tracking...")
+ track_method = "box"
+ cfg["default_track_method"] = track_method
+ auxiliaryfunctions.write_config(config, cfg)
+
+ train_fraction = cfg["TrainingFraction"][trainingsetindex]
+ start_path = os.getcwd() # record cwd to return to this directory in the end
+
+ # TODO: add cropping as in video analysis!
+ # if cropping is not None:
+ # cfg['cropping']=True
+ # cfg['x1'],cfg['x2'],cfg['y1'],cfg['y2']=cropping
+ # print("Overwriting cropping parameters:", cropping)
+ # print("These are used for all videos, but won't be save to the cfg file.")
+
+ rel_model_dir = auxiliaryfunctions.get_model_folder(
+ train_fraction,
+ shuffle,
+ cfg,
+ modelprefix=modelprefix,
+ engine=Engine.PYTORCH,
+ )
+ model_dir = Path(cfg["project_path"]) / rel_model_dir
+ path_test_config = model_dir / "test" / "pose_cfg.yaml"
+ dlc_cfg = auxiliaryfunctions.read_plainconfig(str(path_test_config))
+
+ if "multi-animal" not in dlc_cfg["dataset_type"]:
+ raise ValueError("This function is only required for multianimal projects!")
+
+ if inference_cfg is None:
+ inference_cfg = auxfun_multianimal.read_inferencecfg(
+ model_dir / "test" / "inference_cfg.yaml", cfg
+ )
+ auxfun_multianimal.check_inferencecfg_sanity(cfg, inference_cfg)
+
+ if len(cfg["multianimalbodyparts"]) == 1 and track_method != "box":
+ warnings.warn("Switching to `box` tracker for single point tracking...")
+ track_method = "box"
+ # Also ensure `boundingboxslack` is greater than zero, otherwise overlap
+ # between trackers cannot be evaluated, resulting in empty tracklets.
+ inference_cfg["boundingboxslack"] = max(inference_cfg["boundingboxslack"], 40)
+
+ dlc_scorer = get_scorer_name(
+ cfg,
+ shuffle,
+ train_fraction,
+ snapshot_index=None,
+ detector_index=None,
+ modelprefix=modelprefix,
+ )
+
+ videos = list_videos_in_folder(videos, videotype)
+ if len(videos) == 0:
+ print(f"No videos were found in {videos}")
+ return
+
+ for video in videos:
+ print("Processing... ", video)
+ if destfolder is None:
+ output_path = video.parent
+ else:
+ output_path = Path(destfolder)
+ output_path.mkdir(exist_ok=True, parents=True)
+
+ video_name = video.stem
+
+ data_prefix = video_name + dlc_scorer
+ data_filename = output_path / (data_prefix + ".h5")
+ print(f"Loading From {data_filename}")
+ data, metadata = auxfun_multianimal.LoadFullMultiAnimalData(str(data_filename))
+ if track_method == "ellipse":
+ method = "el"
+ elif track_method == "box":
+ method = "bx"
+ else:
+ method = "sk"
+
+ track_filename = output_path / (data_prefix + f"_{method}.pickle")
+ if not overwrite and track_filename.exists():
+ # TODO: check if metadata are identical (same parameters!)
+ print(f"Tracklets already computed at {track_filename}")
+ print("Set overwrite = True to overwrite.")
+ else:
+ dlc_scorer = metadata["data"]["Scorer"]
+ joints = data["metadata"]["all_joints_names"]
+ n_joints = len(joints)
+
+ # TODO: adjust this for multi + unique bodyparts!
+ # this is only for multianimal parts and unique bodyparts as one (not one
+ # unique bodyparts guy tracked etc.)
+ bodypart_labels = [bpt for bpt in joints for _ in range(3)]
+ scorers = len(bodypart_labels) * [dlc_scorer]
+ xyl_value = int(len(bodypart_labels) / 3) * ["x", "y", "likelihood"]
+ df_index = pd.MultiIndex.from_arrays(
+ np.vstack([scorers, bodypart_labels, xyl_value]),
+ names=["scorer", "bodyparts", "coords"],
+ )
+
+ if track_method == "box":
+ mot_tracker = trackingutils.SORTBox(
+ inference_cfg["max_age"],
+ inference_cfg["min_hits"],
+ inference_cfg.get("oks_threshold", 0.3),
+ )
+ elif track_method == "skeleton":
+ mot_tracker = trackingutils.SORTSkeleton(
+ n_joints,
+ inference_cfg["max_age"],
+ inference_cfg["min_hits"],
+ inference_cfg.get("oks_threshold", 0.5),
+ )
+ else:
+ mot_tracker = trackingutils.SORTEllipse(
+ inference_cfg.get("max_age", 1),
+ inference_cfg.get("min_hits", 1),
+ inference_cfg.get("iou_threshold", 0.6),
+ )
+
+ tracklets = {}
+ multi_bpts = cfg["multianimalbodyparts"]
+
+ ass_filename = data_filename.with_stem(
+ data_filename.stem + "_assemblies"
+ ).with_suffix(".pickle")
+ if not ass_filename.exists():
+ raise FileNotFoundError(
+ f"Could not find the assembles file {ass_filename}. You're "
+ f"converting detections to tracklets using PyTorch, which "
+ "means the assemblies file must be created by the model when "
+ "analyzing the video!"
+ )
+
+ num_frames = data["metadata"]["nframes"]
+ ass = auxiliaryfunctions.read_pickle(ass_filename)
+
+ # Initialize storage of the 'single' individual track
+ if cfg["uniquebodyparts"]:
+ tracklets["single"] = {}
+ _single = {}
+ for index in range(num_frames):
+ single_detection = ass["single"].get(index)
+ if single_detection is None:
+ continue
+ _single[index] = np.asarray(single_detection)
+ tracklets["single"].update(_single)
+
+ pcutoff = inference_cfg.get("pcutoff")
+ if inference_cfg["topktoretain"] == 1:
+ tracklets[0] = {}
+ for index in tqdm(range(num_frames)):
+ assemblies = ass.get(index)
+ if assemblies is None:
+ continue
+
+ assembly = np.asarray(assemblies[0].data)
+ assembly[assembly[..., 2] < pcutoff] = np.nan
+ tracklets[0][index] = assembly
+ else:
+ keep = set(multi_bpts).difference(ignore_bodyparts or [])
+ keep_inds = sorted(multi_bpts.index(bpt) for bpt in keep)
+ for index in tqdm(range(num_frames)):
+ assemblies = ass.get(index)
+ if assemblies is None or len(assemblies) == 0:
+ continue
+
+ animals = np.stack([a for a in assemblies])
+ animals[np.any(animals[..., :3] < 0, axis=-1), :2] = np.nan
+ animals[animals[..., 2] < pcutoff, :2] = np.nan
+ animal_mask = ~np.all(np.isnan(animals[:, :, :2]), axis=(1, 2))
+ if ~np.any(animal_mask):
+ continue
+ animals = animals[animal_mask]
+
+ if identity_only:
+ # Optimal identity assignment based on soft voting
+ mat = np.zeros((len(animals), inference_cfg["topktoretain"]))
+ for row, animal_pose in enumerate(animals):
+ animal_pose = animal_pose[
+ ~np.isnan(animal_pose).any(axis=1)
+ ]
+ unique_ids, idx = np.unique(
+ animal_pose[:, 3], return_inverse=True
+ )
+ total_scores = np.bincount(idx, weights=animal_pose[:, 2])
+ softmax_id_scores = softmax(total_scores)
+ for pred_id, softmax_score in zip(
+ unique_ids.astype(int), softmax_id_scores
+ ):
+ mat[row, pred_id] = softmax_score
+
+ inds = linear_sum_assignment(mat, maximize=True)
+ trackers = np.c_[inds][:, ::-1]
+ else:
+ if track_method == "box":
+ xy = trackingutils.calc_bboxes_from_keypoints(
+ animals[:, keep_inds], inference_cfg["boundingboxslack"]
+ ) # TODO: get cropping parameters and utilize!
+ else:
+ xy = animals[:, keep_inds, :2]
+ trackers = mot_tracker.track(xy)
+
+ trackingutils.fill_tracklets(tracklets, trackers, animals, index)
+
+ tracklets["header"] = df_index
+ with open(track_filename, "wb") as f:
+ pickle.dump(tracklets, f, pickle.HIGHEST_PROTOCOL)
+
+ os.chdir(str(start_path))
+ print(
+ "The tracklets were created (i.e., under the hood "
+ "deeplabcut.convert_detections2tracklets was run). Now you can "
+ "'refine_tracklets' in the GUI, or run 'deeplabcut.stitch_tracklets'."
+ )
+
+
+def _conv_predictions_to_assemblies(
+ image_names: List[str], predictions: Dict[str, np.ndarray]
+) -> Dict[int, List[Assembly]]:
+ """
+ Converts predictions to an assemblies dictionary
+ predictions shape (num_animals, num_keypoints, 2 or 3)
+ """
+ assemblies = {}
+ if len(predictions) == 0:
+ return assemblies
+
+ for image_index, image_name in enumerate(image_names):
+ frame_predictions = predictions.get(image_name)
+ if frame_predictions is not None:
+ num_kpts, num_animals, pred_shape = frame_predictions.shape
+ kpt_lst = []
+ for i in range(num_animals):
+ animal_prediction = frame_predictions[:, i, :]
+ ass_prediction = np.ones((num_kpts, 4), dtype=frame_predictions.dtype)
+ ass_prediction[:, 3] = -ass_prediction[:, 3]
+ ass_prediction[:, :pred_shape] = animal_prediction.copy()
+ ass = Assembly.from_array(ass_prediction)
+ if len(ass) > 0:
+ kpt_lst.append(ass)
+
+ assemblies[image_index] = kpt_lst
+
+ return assemblies
diff --git a/deeplabcut/pose_estimation_pytorch/apis/training.py b/deeplabcut/pose_estimation_pytorch/apis/training.py
new file mode 100644
index 0000000000..d042086e10
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/training.py
@@ -0,0 +1,386 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import argparse
+import copy
+import logging
+from pathlib import Path
+
+import albumentations as A
+from torch.utils.data import DataLoader
+
+import deeplabcut.core.config as config_utils
+import deeplabcut.pose_estimation_pytorch.utils as utils
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.data import (
+ build_transforms,
+ COCOLoader,
+ DLCLoader,
+ Loader,
+)
+from deeplabcut.pose_estimation_pytorch.data.collate import COLLATE_FUNCTIONS
+from deeplabcut.pose_estimation_pytorch.models import DETECTORS, PoseModel
+from deeplabcut.pose_estimation_pytorch.modelzoo.memory_replay import (
+ prepare_memory_replay,
+)
+from deeplabcut.pose_estimation_pytorch.runners import build_training_runner
+from deeplabcut.pose_estimation_pytorch.runners.logger import (
+ destroy_file_logging,
+ LOGGER,
+ setup_file_logging,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+def train(
+ loader: Loader,
+ run_config: dict,
+ task: Task,
+ device: str | None = "cpu",
+ gpus: list[int] | None = None,
+ logger_config: dict | None = None,
+ snapshot_path: str | Path | None = None,
+ transform: A.BaseCompose | None = None,
+ inference_transform: A.BaseCompose | None = None,
+ max_snapshots_to_keep: int | None = None,
+ load_head_weights: bool = True,
+) -> None:
+ """Builds a model from a configuration and fits it to a dataset
+
+ Args:
+ loader: the loader containing the data to train on/validate with
+ run_config: the model and run configuration
+ task: the task to train the model for
+ device: the torch device to train on (such as "cpu", "cuda", "mps")
+ gpus: the list of GPU indices to use for multi-GPU training
+ logger_config: the configuration of a logger to use
+ snapshot_path: if continuing to train from a snapshot, the path containing the
+ weights to load
+ transform: if defined, overwrites the transform defined in the model config
+ inference_transform: if defined, overwrites the inference transform defined in
+ the model config
+ max_snapshots_to_keep: the maximum number of snapshots to store for each model
+ load_head_weights: When `snapshot_path` is not None and a pose model is being
+ trained, whether to load the head weights from the saved snapshot.
+ """
+ weight_init = None
+ pretrained = True
+
+ if weight_init_cfg := run_config["train_settings"].get("weight_init"):
+ weight_init = WeightInitialization.from_dict(weight_init_cfg)
+ pretrained = False
+
+ if task == Task.DETECT:
+ model = DETECTORS.build(
+ run_config["model"],
+ weight_init=weight_init,
+ pretrained=pretrained,
+ )
+
+ else:
+ model = PoseModel.build(
+ run_config["model"],
+ weight_init=weight_init,
+ pretrained_backbone=pretrained,
+ )
+
+ if max_snapshots_to_keep is not None:
+ run_config["runner"]["snapshots"]["max_snapshots"] = max_snapshots_to_keep
+
+ logger = None
+ if logger_config is not None:
+ logger = LOGGER.build(dict(**logger_config, model=model))
+ logger.log_config(run_config)
+
+ if device is None:
+ device = utils.resolve_device(run_config)
+ elif device == "auto":
+ run_config["device"] = device
+ device = utils.resolve_device(run_config)
+
+ if gpus is None:
+ gpus = run_config["runner"].get("gpus")
+
+ if device == "mps" and task == Task.DETECT:
+ device = "cpu" # FIXME: Cannot train detectors on MPS
+
+ if snapshot_path is None:
+ snapshot_path = run_config.get("resume_training_from")
+
+ model.to(device) # Move model before giving its parameters to the optimizer
+ runner = build_training_runner(
+ runner_config=run_config["runner"],
+ model_folder=loader.model_folder,
+ task=task,
+ model=model,
+ device=device,
+ gpus=gpus,
+ snapshot_path=snapshot_path,
+ load_head_weights=load_head_weights,
+ logger=logger,
+ )
+
+ if transform is None:
+ transform = build_transforms(run_config["data"]["train"])
+ if inference_transform is None:
+ inference_transform = build_transforms(run_config["data"]["inference"])
+
+ logging.info("Data Transforms:")
+ logging.info(f" Training: {transform}")
+ logging.info(f" Validation: {inference_transform}")
+
+ train_dataset = loader.create_dataset(transform=transform, mode="train", task=task)
+ valid_dataset = loader.create_dataset(
+ transform=inference_transform, mode="test", task=task
+ )
+
+ collate_fn = None
+ if collate_fn_cfg := run_config["data"]["train"].get("collate"):
+ collate_fn = COLLATE_FUNCTIONS.build(collate_fn_cfg)
+ logging.info(f"Using custom collate function: {collate_fn_cfg}")
+
+ batch_size = run_config["train_settings"]["batch_size"]
+ num_workers = run_config["train_settings"]["dataloader_workers"]
+ pin_memory = run_config["train_settings"]["dataloader_pin_memory"]
+ train_dataloader = DataLoader(
+ train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ )
+ valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False)
+
+ if (
+ loader.model_cfg["model"].get("freeze_bn_stats", False)
+ or loader.model_cfg["model"].get("backbone", {}).get("freeze_bn_stats", False)
+ or batch_size == 1
+ ):
+ logging.info(
+ "\nNote: According to your model configuration, you're training with batch "
+ "size 1 and/or ``freeze_bn_stats=false``. This is not an optimal setting "
+ "if you have powerful GPUs.\n"
+ "This is good for small batch sizes (e.g., when training on a CPU), where "
+ "you should keep ``freeze_bn_stats=true``.\n"
+ "If you're using a GPU to train, you can obtain faster performance by "
+ "setting a larger batch size (the biggest power of 2 where you don't get"
+ "a CUDA out-of-memory error, such as 8, 16, 32 or 64 depending on the "
+ "model, size of your images, and GPU memory) and ``freeze_bn_stats=false`` "
+ "for the backbone of your model. \n"
+ "This also allows you to increase the learning rate (empirically you can "
+ "scale the learning rate by sqrt(batch_size) times).\n"
+ )
+
+ logging.info(
+ f"Using {len(train_dataset)} images and {len(valid_dataset)} for testing"
+ )
+ if task == task.DETECT:
+ logging.info("\nStarting object detector training...\n" + (50 * "-"))
+ else:
+ logging.info("\nStarting pose model training...\n" + (50 * "-"))
+
+ runner.fit(
+ train_dataloader,
+ valid_dataloader,
+ epochs=run_config["train_settings"]["epochs"],
+ display_iters=run_config["train_settings"]["display_iters"],
+ )
+
+
+def train_network(
+ config: str | Path,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ modelprefix: str = "",
+ device: str | None = None,
+ snapshot_path: str | Path | None = None,
+ detector_path: str | Path | None = None,
+ load_head_weights: bool = True,
+ batch_size: int | None = None,
+ epochs: int | None = None,
+ save_epochs: int | None = None,
+ detector_batch_size: int | None = None,
+ detector_epochs: int | None = None,
+ detector_save_epochs: int | None = None,
+ display_iters: int | None = None,
+ max_snapshots_to_keep: int | None = None,
+ pose_threshold: float | None = 0.1,
+ pytorch_cfg_updates: dict | None = None,
+) -> None:
+ """Trains a network for a project
+
+ Args:
+ config : path to the yaml config file of the project
+ shuffle : index of the shuffle we want to train on
+ trainingsetindex : training set index
+ modelprefix: directory containing the deeplabcut configuration files to use
+ to train the network (and where snapshots will be saved). By default, they
+ are assumed to exist in the project folder.
+ device: the torch device to train on (such as "cpu", "cuda", "mps")
+ snapshot_path: if resuming training, the snapshot from which to resume
+ detector_path: if resuming training of a top-down model, used to specify the
+ detector snapshot from which to resume
+ load_head_weights: if resuming training of a pose estimation model (either
+ through the `snapshot_path` attribute or the `resume_training_from` key in
+ the `pytorch_config.yaml` file), setting this to True also loads the weights
+ for the model head (equivalent to the `keepdeconvweights` for TensorFlow
+ models). Note that if you change the number of bodyparts, you need to set
+ this to false for re-training.
+ batch_size: overrides the batch size to train with
+ epochs: overrides the maximum number of epochs to train the model for
+ save_epochs: overrides the number of epochs between each snapshot save
+ detector_batch_size: Only for top-down models. Overrides the batch size with
+ which to train the detector.
+ detector_epochs: Only for top-down models. Overrides the maximum number of
+ epochs to train the model for. Setting to 0 means the detector will not be
+ trained.
+ detector_save_epochs: Only for top-down models. Overrides the number of epochs
+ between each snapshot of the detector is saved.
+ display_iters: overrides the number of iterations between each log of the loss
+ within an epoch
+ max_snapshots_to_keep: the maximum number of snapshots to save for each model
+ pose_threshold: Used for memory-replay. Pseudo-predictions with confidence lower
+ than this threshold are discarded for memory-replay
+ pytorch_cfg_updates: dict, optional, default = None.
+ A dictionary of updates to the pytorch config. The keys are the dot-separated
+ paths to the values to update in the config.
+ For example, to update the gpus to run the training on, you can use:
+ ```
+ pytorch_cfg_updates={"runner.gpus": [0,1,2,3]}
+ ```
+ To see the full list - check the pytorch_cfg.yaml file in your project folder
+ """
+ loader = DLCLoader(
+ config=config,
+ shuffle=shuffle,
+ trainset_index=trainingsetindex,
+ modelprefix=modelprefix,
+ )
+
+ if weight_init_cfg := loader.model_cfg["train_settings"].get("weight_init"):
+ weight_init = WeightInitialization.from_dict(weight_init_cfg)
+ if weight_init.memory_replay:
+ if weight_init.detector_snapshot_path is None:
+ raise ValueError(
+ "When fine-tuning a SuperAnimal model with memory replay, a "
+ "detector must be given as well so animals can be detected in "
+ "images to obtain pseudo-labels. Please update your weight "
+ "initialization so that `detector_snapshot_path` is not None."
+ )
+
+ print("Preparing data for memory replay (this can take some time)")
+ dataset_params = loader.get_dataset_parameters()
+ prepare_memory_replay(
+ config,
+ loader,
+ weight_init.dataset,
+ weight_init.snapshot_path,
+ weight_init.detector_snapshot_path,
+ device,
+ train_file="train.json",
+ max_individuals=dataset_params.max_num_animals,
+ pose_threshold=pose_threshold,
+ )
+
+ print("Loading memory replay data")
+ loader = COCOLoader(
+ project_root=loader.model_folder / "memory_replay",
+ model_config_path=loader.model_config_path,
+ train_json_filename="memory_replay_train.json",
+ )
+
+ if batch_size is not None:
+ loader.model_cfg["train_settings"]["batch_size"] = batch_size
+ if epochs is not None:
+ loader.model_cfg["train_settings"]["epochs"] = epochs
+ if save_epochs is not None:
+ loader.model_cfg["runner"]["snapshots"]["save_epochs"] = save_epochs
+ if display_iters is not None:
+ loader.model_cfg["train_settings"]["display_iters"] = display_iters
+
+ detector_cfg = loader.model_cfg.get("detector")
+ if detector_cfg is not None:
+ if detector_batch_size is not None:
+ detector_cfg["train_settings"]["batch_size"] = detector_batch_size
+ if detector_epochs is not None:
+ detector_cfg["train_settings"]["epochs"] = detector_epochs
+ if detector_save_epochs is not None:
+ detector_cfg["runner"]["snapshots"]["save_epochs"] = detector_save_epochs
+ if display_iters is not None:
+ detector_cfg["train_settings"]["display_iters"] = display_iters
+
+ if pytorch_cfg_updates is not None:
+ loader.update_model_cfg(pytorch_cfg_updates)
+
+ setup_file_logging(loader.model_folder / "train.txt")
+
+ logging.info("Training with configuration:")
+ config_utils.pretty_print(loader.model_cfg, print_fn=logging.info)
+
+ # fix seed for reproducibility
+ utils.fix_seeds(loader.model_cfg["train_settings"]["seed"])
+
+ # get the pose task
+ pose_task = Task(loader.model_cfg.get("method", "bu"))
+ if (
+ pose_task == Task.TOP_DOWN
+ and loader.model_cfg["detector"]["train_settings"]["epochs"] > 0
+ ):
+ logger_config = None
+ if loader.model_cfg.get("logger"):
+ logger_config = copy.deepcopy(loader.model_cfg["logger"])
+ logger_config["run_name"] += "-detector"
+
+ detector_run_config = loader.model_cfg["detector"]
+ detector_run_config["device"] = loader.model_cfg["device"]
+ detector_run_config["train_settings"]["weight_init"] = loader.model_cfg[
+ "train_settings"
+ ].get("weight_init")
+ train(
+ loader=loader,
+ run_config=detector_run_config,
+ task=Task.DETECT,
+ device=device,
+ logger_config=logger_config,
+ snapshot_path=detector_path,
+ max_snapshots_to_keep=max_snapshots_to_keep,
+ )
+
+ if loader.model_cfg["train_settings"]["epochs"] > 0:
+ train(
+ loader=loader,
+ run_config=loader.model_cfg,
+ task=pose_task,
+ device=device,
+ logger_config=loader.model_cfg.get("logger"),
+ snapshot_path=snapshot_path,
+ max_snapshots_to_keep=max_snapshots_to_keep,
+ load_head_weights=load_head_weights,
+ )
+
+ destroy_file_logging()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config-path", type=str)
+ parser.add_argument("--shuffle", type=int, default=1)
+ parser.add_argument("--train-ind", type=int, default=0)
+ parser.add_argument("--modelprefix", type=str, default="")
+ args = parser.parse_args()
+ train_network(
+ config=args.config_path,
+ shuffle=args.shuffle,
+ trainingsetindex=args.train_ind,
+ modelprefix=args.modelprefix,
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/apis/utils.py b/deeplabcut/pose_estimation_pytorch/apis/utils.py
new file mode 100644
index 0000000000..db1fe0321a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/utils.py
@@ -0,0 +1,723 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import logging
+import random
+from pathlib import Path
+from typing import Callable
+
+import albumentations as A
+import numpy as np
+import pandas as pd
+
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.core.engine import Engine
+from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters
+from deeplabcut.pose_estimation_pytorch.data.dlcloader import (
+ build_dlc_dataframe_columns,
+)
+from deeplabcut.pose_estimation_pytorch.data.postprocessor import (
+ build_bottom_up_postprocessor,
+ build_detector_postprocessor,
+ build_top_down_postprocessor,
+)
+from deeplabcut.pose_estimation_pytorch.data.preprocessor import (
+ build_bottom_up_preprocessor,
+ build_top_down_preprocessor,
+)
+from deeplabcut.pose_estimation_pytorch.data.transforms import build_transforms
+from deeplabcut.pose_estimation_pytorch.models import DETECTORS, PoseModel
+from deeplabcut.pose_estimation_pytorch.runners import (
+ build_inference_runner,
+ DetectorInferenceRunner,
+ DynamicCropper,
+ InferenceRunner,
+ PoseInferenceRunner,
+)
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import (
+ Snapshot,
+ TorchSnapshotManager,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.pose_estimation_pytorch.utils import resolve_device
+from deeplabcut.utils import auxfun_videos, auxiliaryfunctions
+
+
+def parse_snapshot_index_for_analysis(
+ cfg: dict,
+ model_cfg: dict,
+ snapshot_index: int | str | None,
+ detector_snapshot_index: int | str | None,
+) -> tuple[int, int | None]:
+ """Gets the index of the snapshots to use for data analysis (e.g. video analysis)
+
+ Args:
+ cfg: The project configuration.
+ model_cfg: The model configuration.
+ snapshot_index: The index of the snapshot to use, if one was given by the user.
+ detector_snapshot_index: The index of the detector snapshot to use, if one
+ was given by the user.
+
+ Returns:
+ snapshot_index: the snapshot index to use for analysis
+ detector_snapshot_index: the detector index to use for analysis, or None if no
+ detector should be used
+ """
+ if snapshot_index is None:
+ snapshot_index = cfg["snapshotindex"]
+ if snapshot_index == "all":
+ logging.warning(
+ "snapshotindex is set to 'all' (in the config.yaml file or as given to "
+ "`analyze_...`). Running data analysis with all snapshots is very "
+ "costly! Use the function 'evaluate_network' to choose the best the "
+ "snapshot. For now, changing snapshot index to -1. To evaluate another "
+ "snapshot, you can change the value in the config file or call "
+ "`analyze_videos` or `analyze_images` with your desired snapshot index."
+ )
+ snapshot_index = -1
+
+ pose_task = Task(model_cfg["method"])
+ if pose_task == Task.TOP_DOWN:
+ if detector_snapshot_index is None:
+ detector_snapshot_index = cfg.get("detector_snapshotindex", -1)
+
+ if detector_snapshot_index == "all":
+ logging.warning(
+ f"detector_snapshotindex is set to '{detector_snapshot_index}' (in the "
+ "config.yaml file or as given to `analyze_...`). Running data analysis "
+ "with all snapshots is very costly! Use 'evaluate_network' to choose "
+ "the best detector snapshot. For now, changing the detector snapshot "
+ "index to -1. To evaluate another detector snapshot, you can change "
+ "the value in the config file or call `analyze_videos` or "
+ "`analyze_images` with your desired detector snapshot index."
+ )
+ detector_snapshot_index = -1
+
+ else:
+ detector_snapshot_index = None
+
+ return snapshot_index, detector_snapshot_index
+
+
+def return_train_network_path(
+ config: str, shuffle: int = 1, trainingsetindex: int = 0, modelprefix: str = ""
+) -> tuple[Path, Path, Path]:
+ """
+ Args:
+ config: Full path of the config.yaml file as a string.
+ shuffle: The shuffle index to select for training
+ trainingsetindex: Which TrainingsetFraction to use (note that TrainingFraction
+ is a list in config.yaml)
+ modelprefix: the modelprefix for the model
+
+ Returns:
+ the path to the training pytorch pose configuration file
+ the path to the test pytorch pose configuration file
+ the path to the folder containing the snapshots
+ """
+ cfg = auxiliaryfunctions.read_config(config)
+ project_path = Path(cfg["project_path"])
+ train_frac = cfg["TrainingFraction"][trainingsetindex]
+ model_folder = auxiliaryfunctions.get_model_folder(
+ train_frac, shuffle, cfg, engine=Engine.PYTORCH, modelprefix=modelprefix
+ )
+ return (
+ project_path / model_folder / "train" / "pytorch_config.yaml",
+ project_path / model_folder / "test" / "pose_cfg.yaml",
+ project_path / model_folder / "train",
+ )
+
+
+def get_model_snapshots(
+ index: int | str,
+ model_folder: Path,
+ task: Task,
+) -> list[Snapshot]:
+ """
+ Args:
+ index: Passing an index returns the snapshot with that index (where snapshots
+ based on their number of training epochs, and the last snapshot is the
+ "best" model based on validation metrics if one exists). Passing "best"
+ returns the best snapshot from the training run. Passing "all" returns all
+ snapshots.
+ model_folder: The path to the folder containing the snapshots
+ task: The task for which to return the snapshot
+
+ Returns:
+ If index=="all", returns all snapshots. Otherwise, returns a list containing a
+ single snapshot, with the desired index.
+
+ Raises:
+ ValueError: If the index given is not valid
+ ValueError: If index=="best" but there is no saved best model
+ """
+ snapshot_manager = TorchSnapshotManager(
+ model_folder=model_folder, snapshot_prefix=task.snapshot_prefix
+ )
+ if isinstance(index, str) and index.lower() == "best":
+ best_snapshot = snapshot_manager.best()
+ if best_snapshot is None:
+ raise ValueError(f"No best snapshot found in {model_folder}")
+ snapshots = [best_snapshot]
+ elif isinstance(index, str) and index.lower() == "all":
+ snapshots = snapshot_manager.snapshots()
+ elif isinstance(index, int):
+ all_snapshots = snapshot_manager.snapshots()
+ if (
+ len(all_snapshots) == 0
+ or len(all_snapshots) <= index
+ or (index < 0 and len(all_snapshots) < -index)
+ ):
+ names = [s.path.name for s in all_snapshots]
+ raise ValueError(
+ f"Found {len(all_snapshots)} snapshots in {model_folder} (with names "
+ f"{names}) with prefix {snapshot_manager.snapshot_prefix}. Could "
+ f"not return snapshot with index {index}."
+ )
+
+ snapshots = [all_snapshots[index]]
+ else:
+ raise ValueError(f"Invalid snapshotindex: {index}")
+
+ return snapshots
+
+
+def get_scorer_uid(snapshot: Snapshot, detector_snapshot: Snapshot | None) -> str:
+ """
+ Args:
+ snapshot: the snapshot for which to get the scorer UID
+ detector_snapshot: if a top-down model is used with a detector, the detector
+ snapshot for which to get the scorer UID
+
+ Returns:
+ the uid to use for the scorer
+ """
+ snapshot_id = f"snapshot_{snapshot.uid()}"
+ if detector_snapshot is not None:
+ detect_id = detector_snapshot.uid()
+ snapshot_id = f"detector_{detect_id}_{snapshot_id}"
+ return snapshot_id
+
+
+def get_scorer_name(
+ cfg: dict,
+ shuffle: int,
+ train_fraction: float,
+ snapshot_index: int | None = None,
+ detector_index: int | None = None,
+ snapshot_uid: str | None = None,
+ modelprefix: str = "",
+) -> str:
+ """Get the scorer name for a particular PyTorch DeepLabCut shuffle
+
+ Args:
+ cfg: The project configuration.
+ shuffle: The index of the shuffle for which to get the scorer
+ train_fraction: The training fraction for the shuffle.
+ snapshot_index: The index of the snapshot used. If None, the value is loaded
+ from the project's config.yaml file.
+ detector_index: For top-down models, the index of the detector used. If None,
+ the value is loaded from the project's config.yaml file.
+ snapshot_uid: If the snapshot_uid is not None, this value will be used instead
+ of loading the snapshot and detector with given indices and calling
+ utils.get_scorer_uid.
+ modelprefix: The model prefix, if one was used.
+
+ Returns:
+ the scorer name
+ """
+ model_dir = Path(cfg["project_path"]) / auxiliaryfunctions.get_model_folder(
+ train_fraction,
+ shuffle,
+ cfg,
+ engine=Engine.PYTORCH,
+ modelprefix=modelprefix,
+ )
+ train_dir = model_dir / "train"
+ model_cfg = read_config_as_dict(str(train_dir / Engine.PYTORCH.pose_cfg_name))
+ net_type = model_cfg["net_type"]
+ pose_task = Task(model_cfg["method"])
+
+ if snapshot_uid is None:
+ if snapshot_index is None:
+ snapshot_index = auxiliaryfunctions.get_snapshot_index_for_scorer(
+ "snapshotindex", cfg["snapshotindex"]
+ )
+ if detector_index is None:
+ detector_index = auxiliaryfunctions.get_snapshot_index_for_scorer(
+ "detector_snapshotindex", cfg["detector_snapshotindex"]
+ )
+
+ snapshot = get_model_snapshots(snapshot_index, train_dir, pose_task)[0]
+ detector_snapshot = None
+ if detector_index is not None and pose_task == Task.TOP_DOWN:
+ detector_snapshot = get_model_snapshots(
+ detector_index, train_dir, Task.DETECT
+ )[0]
+
+ snapshot_uid = get_scorer_uid(snapshot, detector_snapshot)
+
+ task, date = cfg["Task"], cfg["date"]
+ name = "".join([p.capitalize() for p in net_type.split("_")])
+ return f"DLC_{name}_{task}{date}shuffle{shuffle}_{snapshot_uid}"
+
+
+def list_videos_in_folder(
+ data_path: str | list[str],
+ video_type: str | None,
+ shuffle: bool = False,
+) -> list[Path]:
+ """
+ Args:
+ data_path: Path or list of paths to folders containing videos
+ video_type: The type of video to filter for
+ shuffle: If the paths point to directories, whether to shuffle the order of
+ videos in the directory.
+
+ Returns:
+ The paths of videos to analyze.
+ """
+ if not isinstance(data_path, list):
+ data_path = [data_path]
+ video_paths = [Path(p) for p in data_path]
+
+ videos = []
+ for path in video_paths:
+ if path.is_dir():
+ if not video_type:
+ video_suffixes = ["." + ext for ext in auxfun_videos.SUPPORTED_VIDEOS]
+ else:
+ video_suffixes = [video_type]
+
+ suffixes = [s if s.startswith(".") else "." + s for s in video_suffixes]
+ videos_in_dir = [file for file in path.iterdir() if file.suffix in suffixes]
+ if shuffle:
+ random.shuffle(videos_in_dir)
+ videos += videos_in_dir
+ else:
+ assert (
+ path.exists()
+ ), f"Could not find the video: {path}. Check access rights."
+ videos.append(path)
+
+ return videos
+
+
+def ensure_multianimal_df_format(df_predictions: pd.DataFrame) -> pd.DataFrame:
+ """
+ Convert dataframe to 'multianimal' format (with an "individuals" columns index)
+
+ Args:
+ df_predictions: the dataframe to convert
+
+ Returns:
+ the dataframe in MA format
+ """
+ df_predictions_ma = df_predictions.copy()
+ try:
+ df_predictions_ma.columns.get_level_values("individuals").unique().tolist()
+ except KeyError:
+ new_cols = pd.MultiIndex.from_tuples(
+ [(col[0], "animal", col[1], col[2]) for col in df_predictions_ma.columns],
+ names=["scorer", "individuals", "bodyparts", "coords"],
+ )
+ df_predictions_ma.columns = new_cols
+ return df_predictions_ma
+
+
+def _image_names_to_df_index(
+ image_names: list[str],
+ image_name_to_index: Callable[[str], tuple[str, ...]] | None = None,
+) -> pd.MultiIndex | list[str]:
+ """
+ Creates index for predictions dataframe.
+ This method is used in build_predictions_dataframe, but also in build_bboxes_dict_for_dataframe.
+ It is important that these two methods return objects with the same index / keys.
+
+ Args:
+ image_names: list of image names
+ image_name_to_index, optional: a transform to apply on each image_name
+ """
+
+ if image_name_to_index is not None:
+ return pd.MultiIndex.from_tuples(
+ [image_name_to_index(image_name) for image_name in image_names]
+ )
+ else:
+ return image_names
+
+
+def build_predictions_dataframe(
+ scorer: str,
+ predictions: dict[str, dict[str, np.ndarray]],
+ parameters: PoseDatasetParameters,
+ image_name_to_index: Callable[[str], tuple[str, ...]] | None = None,
+) -> pd.DataFrame:
+ """
+ Builds a pandas DataFrame from pose prediction data. The resulting DataFrame
+ includes properly formatted indices and column names for compatibility with
+ DeepLabCut workflows.
+
+ Args:
+ scorer: The name of the scorer used to generate the predictions.
+ predictions: A dictionary where each key is an image name and its value is
+ another dictionary. The inner dictionary contains prediction data for
+ "bodyparts" and optionally "unique_bodyparts". The "bodyparts" and
+ "unique_bodyparts" data arrays are expected to be 3-dimensional, containing
+ pose predictions in format (num_predicted_individuals, num_bodyparts, 3).
+ parameters: Dataset-specific parameters required for constructing DataFrame
+ columns.
+ image_name_to_index: A callable function that takes an image name and returns
+ a tuple representing the DataFrame index. If None, indices will be
+ generated without transformation.
+
+ Returns:
+ A pandas DataFrame containing the processed prediction data for all provided
+ images. The DataFrame index corresponds to the image names or their
+ transformed values (if `image_name_to_index` is provided). The DataFrame
+ columns are constructed using the provided scorer and parameters.
+ """
+ image_names = []
+ prediction_data = []
+ for image_name, image_predictions in predictions.items():
+ image_data = image_predictions["bodyparts"][..., :3].reshape(-1)
+ if "unique_bodyparts" in image_predictions:
+ image_data = np.concatenate(
+ [image_data, image_predictions["unique_bodyparts"][..., :3].reshape(-1)]
+ )
+ image_names.append(image_name)
+ prediction_data.append(image_data)
+
+ index = _image_names_to_df_index(image_names, image_name_to_index)
+
+ return pd.DataFrame(
+ prediction_data,
+ index=index,
+ columns=build_dlc_dataframe_columns(
+ scorer=scorer,
+ parameters=parameters,
+ with_likelihood=True,
+ ),
+ )
+
+
+def build_bboxes_dict_for_dataframe(
+ predictions: dict[str, dict[str, np.ndarray]],
+ image_name_to_index: Callable[[str], tuple[str, ...]] | None = None,
+) -> dict:
+ """
+ Creates a dictionary with bounding boxes from predictions.
+
+ The keys of the dictionary are the same as the index of the dataframe created by
+ build_predictions_dataframe. Therefore, the structures returned by
+ build_predictions_dataframe and by build_bboxes_dict_for_dataframe can be accessed
+ with the same keys.
+
+ Args:
+ predictions: Dictionary containing the evaluation results
+ image_name_to_index: a transform to apply on each image_name
+
+ Returns:
+ Dictionary with sames keys as in the dataframe returned by
+ build_predictions_dataframe, and respective bounding boxes and scores, if any.
+ """
+
+ image_names = []
+ bboxes_data = []
+ for image_name, image_predictions in predictions.items():
+ image_names.append(image_name)
+ if "bboxes" in image_predictions and "bbox_scores" in image_predictions:
+ bboxes_data.append(
+ (image_predictions["bboxes"], image_predictions["bbox_scores"])
+ )
+
+ index = _image_names_to_df_index(image_names, image_name_to_index)
+
+ return dict(zip(index, bboxes_data))
+
+
+def get_inference_runners(
+ model_config: dict,
+ snapshot_path: str | Path,
+ max_individuals: int | None = None,
+ num_bodyparts: int | None = None,
+ num_unique_bodyparts: int | None = None,
+ batch_size: int = 1,
+ device: str | None = None,
+ with_identity: bool = False,
+ transform: A.BaseCompose | None = None,
+ detector_batch_size: int = 1,
+ detector_path: str | Path | None = None,
+ detector_transform: A.BaseCompose | None = None,
+ dynamic: DynamicCropper | None = None,
+) -> tuple[InferenceRunner, InferenceRunner | None]:
+ """Builds the runners for pose estimation
+
+ Args:
+ model_config: the pytorch configuration file
+ snapshot_path: the path of the snapshot from which to load the weights
+ max_individuals: the maximum number of individuals per image (if None, uses the
+ individuals defined in the model_config metadata)
+ num_bodyparts: the number of bodyparts predicted by the model (if None, uses the
+ bodyparts defined in the model_config metadata)
+ num_unique_bodyparts: the number of unique_bodyparts predicted by the model (if
+ None, uses the unique bodyparts defined in the model_config metadata)
+ batch_size: the batch size to use for the pose model.
+ with_identity: whether the pose model has an identity head
+ device: if defined, overwrites the device selection from the model config
+ transform: the transform for pose estimation. if None, uses the transform
+ defined in the config.
+ detector_batch_size: the batch size to use for the detector
+ detector_path: the path to the detector snapshot from which to load weights,
+ for top-down models (if a detector runner is needed)
+ detector_transform: the transform for object detection. if None, uses the
+ transform defined in the config.
+ dynamic: The DynamicCropper used for video inference, or None if dynamic
+ cropping should not be used. Only for bottom-up pose estimation models.
+ Should only be used when creating inference runners for video pose
+ estimation with batch size 1.
+
+ Returns:
+ a runner for pose estimation
+ a runner for detection, if detector_path is not None
+ """
+ if max_individuals is None:
+ max_individuals = len(model_config["metadata"]["individuals"])
+ if num_bodyparts is None:
+ num_bodyparts = len(model_config["metadata"]["bodyparts"])
+ if num_unique_bodyparts is None:
+ num_unique_bodyparts = len(model_config["metadata"]["unique_bodyparts"])
+
+ pose_task = Task(model_config["method"])
+ if device is None:
+ device = resolve_device(model_config)
+
+ if transform is None:
+ transform = build_transforms(model_config["data"]["inference"])
+
+ detector_runner = None
+ if pose_task == Task.BOTTOM_UP:
+ pose_preprocessor = build_bottom_up_preprocessor(
+ color_mode=model_config["data"]["colormode"],
+ transform=transform,
+ )
+ pose_postprocessor = build_bottom_up_postprocessor(
+ max_individuals=max_individuals,
+ num_bodyparts=num_bodyparts,
+ num_unique_bodyparts=num_unique_bodyparts,
+ with_identity=with_identity,
+ )
+ else:
+ # FIXME: Cannot run detectors on MPS
+ detector_device = device
+ if device == "mps":
+ detector_device = "cpu"
+
+ crop_cfg = model_config["data"]["inference"].get("top_down_crop", {})
+ width, height = crop_cfg.get("width", 256), crop_cfg.get("height", 256)
+ margin = crop_cfg.get("margin", 0)
+
+ pose_preprocessor = build_top_down_preprocessor(
+ color_mode=model_config["data"]["colormode"],
+ transform=transform,
+ top_down_crop_size=(width, height),
+ top_down_crop_margin=margin,
+ )
+ pose_postprocessor = build_top_down_postprocessor(
+ max_individuals=max_individuals,
+ num_bodyparts=num_bodyparts,
+ num_unique_bodyparts=num_unique_bodyparts,
+ )
+
+ if detector_path is not None:
+ if detector_transform is None:
+ detector_transform = build_transforms(
+ model_config["detector"]["data"]["inference"]
+ )
+
+ detector_config = model_config["detector"]["model"]
+ if "pretrained" in detector_config:
+ detector_config["pretrained"] = False
+
+ detector_runner = build_inference_runner(
+ task=Task.DETECT,
+ model=DETECTORS.build(detector_config),
+ device=detector_device,
+ snapshot_path=detector_path,
+ batch_size=detector_batch_size,
+ preprocessor=build_bottom_up_preprocessor(
+ color_mode=model_config["detector"]["data"]["colormode"],
+ transform=detector_transform,
+ ),
+ postprocessor=build_detector_postprocessor(
+ max_individuals=max_individuals,
+ ),
+ load_weights_only=model_config["detector"]["runner"].get(
+ "load_weights_only",
+ None,
+ ),
+ )
+
+ pose_runner = build_inference_runner(
+ task=pose_task,
+ model=PoseModel.build(model_config["model"]),
+ device=device,
+ snapshot_path=snapshot_path,
+ batch_size=batch_size,
+ preprocessor=pose_preprocessor,
+ postprocessor=pose_postprocessor,
+ dynamic=dynamic,
+ load_weights_only=model_config["runner"].get("load_weights_only", None),
+ )
+ return pose_runner, detector_runner
+
+
+def get_detector_inference_runner(
+ model_config: dict,
+ snapshot_path: str | Path,
+ batch_size: int = 1,
+ device: str | None = None,
+ max_individuals: int | None = None,
+ transform: A.BaseCompose | None = None,
+) -> DetectorInferenceRunner:
+ """Builds an inference runner for object detection.
+
+ Args:
+ model_config: the pytorch configuration file
+ snapshot_path: the path of the snapshot from which to load the weights
+ max_individuals: the maximum number of individuals per image
+ batch_size: the batch size to use for the pose model.
+ device: if defined, overwrites the device selection from the model config
+ transform: the transform for pose estimation. if None, uses the transform
+ defined in the config.
+
+ Returns:
+ an inference runner for object detection
+ """
+ if device is None:
+ device = resolve_device(model_config)
+ elif device == "mps": # FIXME(niels): Cannot run detectors on MPS
+ device = "cpu"
+
+ if max_individuals is None:
+ max_individuals = len(model_config["metadata"]["individuals"])
+
+ det_cfg = model_config["detector"]
+ if transform is None:
+ transform = build_transforms(det_cfg["data"]["inference"])
+
+ if "pretrained" in det_cfg["model"]:
+ det_cfg["model"]["pretrained"] = False
+
+ preprocessor = build_bottom_up_preprocessor(det_cfg["data"]["colormode"], transform)
+ postprocessor = build_detector_postprocessor(max_individuals=max_individuals)
+ runner = build_inference_runner(
+ task=Task.DETECT,
+ model=DETECTORS.build(det_cfg["model"]),
+ device=device,
+ snapshot_path=snapshot_path,
+ batch_size=batch_size,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ load_weights_only=det_cfg["runner"].get("load_weights_only", None),
+ )
+
+ if not isinstance(runner, DetectorInferenceRunner):
+ raise RuntimeError(f"Failed to build DetectorInferenceRunner: {model_config}")
+
+ return runner
+
+
+def get_pose_inference_runner(
+ model_config: dict,
+ snapshot_path: str | Path,
+ batch_size: int = 1,
+ device: str | None = None,
+ max_individuals: int | None = None,
+ transform: A.BaseCompose | None = None,
+ dynamic: DynamicCropper | None = None,
+) -> PoseInferenceRunner:
+ """Builds an inference runner for pose estimation.
+
+ Args:
+ model_config: the pytorch configuration file
+ snapshot_path: the path of the snapshot from which to load the weights
+ max_individuals: the maximum number of individuals per image
+ batch_size: the batch size to use for the pose model.
+ device: if defined, overwrites the device selection from the model config
+ transform: the transform for pose estimation. if None, uses the transform
+ defined in the config.
+ dynamic: The DynamicCropper used for video inference, or None if dynamic
+ cropping should not be used. Only for bottom-up pose estimation models.
+ Should only be used when creating inference runners for video pose
+ estimation with batch size 1.
+
+ Returns:
+ an inference runner for pose estimation
+ """
+ pose_task = Task(model_config["method"])
+ metadata = model_config["metadata"]
+ num_bodyparts = len(metadata["bodyparts"])
+ num_unique = len(metadata["unique_bodyparts"])
+ with_identity = bool(metadata["with_identity"])
+ if max_individuals is None:
+ max_individuals = len(metadata["individuals"])
+
+ if device is None:
+ device = resolve_device(model_config)
+
+ if transform is None:
+ transform = build_transforms(model_config["data"]["inference"])
+
+ if pose_task == Task.BOTTOM_UP:
+ pose_preprocessor = build_bottom_up_preprocessor(
+ color_mode=model_config["data"]["colormode"],
+ transform=transform,
+ )
+ pose_postprocessor = build_bottom_up_postprocessor(
+ max_individuals=max_individuals,
+ num_bodyparts=num_bodyparts,
+ num_unique_bodyparts=num_unique,
+ with_identity=with_identity,
+ )
+ else:
+ crop_cfg = model_config["data"]["inference"].get("top_down_crop", {})
+ width, height = crop_cfg.get("width", 256), crop_cfg.get("height", 256)
+ margin = crop_cfg.get("margin", 0)
+
+ pose_preprocessor = build_top_down_preprocessor(
+ color_mode=model_config["data"]["colormode"],
+ transform=transform,
+ top_down_crop_size=(width, height),
+ top_down_crop_margin=margin,
+ )
+ pose_postprocessor = build_top_down_postprocessor(
+ max_individuals=max_individuals,
+ num_bodyparts=num_bodyparts,
+ num_unique_bodyparts=num_unique,
+ )
+
+ runner = build_inference_runner(
+ task=pose_task,
+ model=PoseModel.build(model_config["model"]),
+ device=device,
+ snapshot_path=snapshot_path,
+ batch_size=batch_size,
+ preprocessor=pose_preprocessor,
+ postprocessor=pose_postprocessor,
+ dynamic=dynamic,
+ load_weights_only=model_config["runner"].get("load_weights_only", None),
+ )
+ if not isinstance(runner, PoseInferenceRunner):
+ raise RuntimeError(f"Failed to build PoseInferenceRunner for {model_config}")
+
+ return runner
diff --git a/deeplabcut/pose_estimation_pytorch/apis/videos.py b/deeplabcut/pose_estimation_pytorch/apis/videos.py
new file mode 100644
index 0000000000..7d0ce69516
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/videos.py
@@ -0,0 +1,783 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import copy
+import logging
+import pickle
+import time
+from pathlib import Path
+from typing import Any
+
+import albumentations as A
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+import deeplabcut.pose_estimation_pytorch.apis.utils as utils
+import deeplabcut.pose_estimation_pytorch.runners.shelving as shelving
+from deeplabcut.core.engine import Engine
+from deeplabcut.pose_estimation_pytorch.apis.tracklets import (
+ convert_detections2tracklets,
+)
+from deeplabcut.pose_estimation_pytorch.runners import InferenceRunner, DynamicCropper
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.refine_training_dataset.stitch import stitch_tracklets
+from deeplabcut.utils import auxiliaryfunctions, VideoReader
+
+
+class VideoIterator(VideoReader):
+ """A class to iterate over videos, with possible added context"""
+
+ def __init__(
+ self,
+ video_path: str | Path,
+ context: list[dict[str, Any]] | None = None,
+ cropping: list[int] | None = None,
+ ) -> None:
+ super().__init__(str(video_path))
+ self._context = context
+ self._index = 0
+ self._crop = cropping is not None
+ if self._crop:
+ self.set_bbox(*cropping)
+
+ def set_crop(self, cropping: list[int] | None = None) -> None:
+ """Sets the cropping parameters for the video."""
+ self._crop = cropping is not None
+ if self._crop:
+ self.set_bbox(*cropping)
+ else:
+ self.set_bbox(0, 1, 0, 1, relative=True)
+
+ def get_context(self) -> list[dict[str, Any]] | None:
+ if self._context is None:
+ return None
+
+ return copy.deepcopy(self._context)
+
+ def set_context(self, context: list[dict[str, Any]] | None) -> None:
+ if context is None:
+ self._context = None
+ return
+
+ self._context = copy.deepcopy(context)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self) -> np.ndarray | tuple[str, dict[str, Any]]:
+ frame = self.read_frame(crop=self._crop)
+ if frame is None:
+ self._index = 0
+ self.reset()
+ raise StopIteration
+
+ # Otherwise ValueError: At least one stride in the given numpy array is negative,
+ # and tensors with negative strides are not currently supported. (You can probably
+ # work around this by making a copy of your array with array.copy().)
+ frame = frame.copy()
+ if self._context is None:
+ self._index += 1
+ return frame
+
+ context = copy.deepcopy(self._context[self._index])
+ self._index += 1
+ return frame, context
+
+
+def video_inference(
+ video: str | Path | VideoIterator,
+ pose_runner: InferenceRunner,
+ detector_runner: InferenceRunner | None = None,
+ cropping: list[int] | None = None,
+ shelf_writer: shelving.ShelfWriter | None = None,
+ robust_nframes: bool = False,
+) -> list[dict[str, np.ndarray]]:
+ """Runs inference on a video
+
+ Args:
+ video: The video to analyze
+ pose_runner: The pose runner to run inference with
+ detector_runner: When the pose model is a top-down model, a detector runner can
+ be given to obtain bounding boxes for the video. If the pose model is a
+ top-down model and no detector_runner is given, the bounding boxes must
+ already be set in the VideoIterator (see examples).
+ cropping: Optionally, video inference can be run on a cropped version of the
+ video. To do so, pass a list containing 4 elements to specify which area
+ of the video should be analyzed: ``[xmin, xmax, ymin, ymax]``.
+ shelf_writer: By default, data are dumped in a pickle file at the end of the
+ video analysis. Passing a shelf manager writes data to disk on-the-fly
+ using a "shelf" (a pickle-based, persistent, database-like object by
+ default, resulting in constant memory footprint). The returned list is
+ then empty.
+ robust_nframes: Evaluate a video's number of frames in a robust manner. This
+ option is slower (as the whole video is read frame-by-frame), but does not
+ rely on metadata, hence its robustness against file corruption.
+
+ Returns:
+ Predictions for each frame in the video. If a shelf_manager is given, this list
+ will be empty and the predictions will exclusively be stored in the file written
+ by the shelf.
+
+ Examples:
+ Bottom-up video analysis:
+ >>> import deeplabcut.pose_estimation_pytorch as pep
+ >>> from deeplabcut.core.config_utils import read_config_as_dict
+ >>> model_cfg = read_config_as_dict("pytorch_config.yaml")
+ >>> runner = pep.get_pose_inference_runner(model_cfg, "snapshot.pt")
+ >>> video_predictions = pep.video_inference("video.mp4", runner)
+ >>>
+
+ Top-down video analysis:
+ >>> import deeplabcut.pose_estimation_pytorch as pep
+ >>> from deeplabcut.core.config_utils import read_config_as_dict
+ >>> model_cfg = read_config_as_dict("pytorch_config.yaml")
+ >>> runner = pep.get_pose_inference_runner(model_cfg, "snapshot.pt")
+ >>> d_runner = pep.get_pose_inference_runner(model_cfg, "snapshot-detector.pt")
+ >>> video_predictions = pep.video_inference("video.mp4", runner, d_runner)
+ >>>
+
+ Top-Down pose estimation with pre-computed bounding boxes:
+ >>> import numpy as np
+ >>> import deeplabcut.pose_estimation_pytorch as pep
+ >>> from deeplabcut.core.config_utils import read_config_as_dict
+ >>>
+ >>> video_iterator = pep.VideoIterator("video.mp4")
+ >>> video_iterator.set_context([
+ >>> { # frame 1 context
+ >>> "bboxes": np.array([[12, 17, 4, 5]]), # format (x0, y0, w, h)
+ >>> },
+ >>> { # frame 1 context
+ >>> "bboxes": np.array([[12, 17, 4, 5], [18, 92, 54, 32]]),
+ >>> },
+ >>> ...
+ >>> ])
+ >>> model_cfg = read_config_as_dict("pytorch_config.yaml")
+ >>> runner = pep.get_pose_inference_runner(model_cfg, "snapshot.pt")
+ >>> video_predictions = pep.video_inference(video_iterator, runner)
+ >>>
+ """
+ if not isinstance(video, VideoIterator):
+ video = VideoIterator(str(video), cropping=cropping)
+ elif cropping is not None:
+ video.set_crop(cropping)
+
+ n_frames = video.get_n_frames(robust=robust_nframes)
+ vid_w, vid_h = video.dimensions
+ print(f"Starting to analyze {video.video_path}")
+ print(
+ f"Video metadata: \n"
+ f" Overall # of frames: {n_frames}\n"
+ f" Duration of video [s]: {n_frames / max(1, video.fps):.2f}\n"
+ f" fps: {video.fps}\n"
+ f" resolution: w={vid_w}, h={vid_h}\n"
+ )
+
+ if detector_runner is not None:
+ print(f"Running detector with batch size {detector_runner.batch_size}")
+ bbox_predictions = detector_runner.inference(images=tqdm(video))
+ video.set_context(bbox_predictions)
+
+ print(f"Running pose prediction with batch size {pose_runner.batch_size}")
+ if shelf_writer is not None:
+ shelf_writer.open()
+
+ predictions = pose_runner.inference(images=tqdm(video), shelf_writer=shelf_writer)
+ if shelf_writer is not None:
+ shelf_writer.close()
+
+ if shelf_writer is None and len(predictions) != n_frames:
+ tip_url = "https://deeplabcut.github.io/DeepLabCut/docs/recipes/io.html"
+ header = "#tips-on-video-re-encoding-and-preprocessing"
+ logging.warning(
+ f"The video metadata indicates that there {n_frames} in the video, but "
+ f"only {len(predictions)} were able to be processed. This can happen if "
+ "the video is corrupted. You can try to fix the issue by re-encoding your "
+ f"video (tips on how to do that: {tip_url}{header})"
+ )
+
+ return predictions
+
+
+def analyze_videos(
+ config: str,
+ videos: str | list[str],
+ videotype: str | None = None,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ save_as_csv: bool = False,
+ in_random_order: bool = False,
+ snapshot_index: int | str | None = None,
+ detector_snapshot_index: int | str | None = None,
+ device: str | None = None,
+ destfolder: str | None = None,
+ batch_size: int | None = None,
+ detector_batch_size: int | None = None,
+ dynamic: tuple[bool, float, int] = (False, 0.5, 10),
+ modelprefix: str = "",
+ use_shelve: bool = False,
+ robust_nframes: bool = False,
+ transform: A.Compose | None = None,
+ auto_track: bool | None = True,
+ n_tracks: int | None = None,
+ animal_names: list[str] | None = None,
+ calibrate: bool = False,
+ identity_only: bool | None = False,
+ overwrite: bool = False,
+ cropping: list[int] | None = None,
+ save_as_df: bool = False,
+) -> str:
+ """Makes prediction based on a trained network.
+
+ The index of the trained network is specified by parameters in the config file
+ (in particular the variable 'snapshot_index').
+
+ Args:
+ config: full path of the config.yaml file for the project
+ videos: a str (or list of strings) containing the full paths to videos for
+ analysis or a path to the directory, where all the videos with same
+ extension are stored.
+ videotype: checks for the extension of the video in case the input to the video
+ is a directory. Only videos with this extension are analyzed. If left
+ unspecified, keeps videos with extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv').
+ shuffle: An integer specifying the shuffle index of the training dataset used for
+ training the network.
+ trainingsetindex: Integer specifying which TrainingsetFraction to use.
+ save_as_csv: For multi-animal projects and when `auto_track=True`, passed
+ along to the `stitch_tracklets` method to save tracks as CSV.
+ in_random_order: Whether or not to analyze videos in a random order. This is
+ only relevant when specifying a video directory in `videos`.
+ device: the device to use for video analysis
+ destfolder: specifies the destination folder for analysis data. If ``None``,
+ the path of the video is used. Note that for subsequent analysis this
+ folder also needs to be passed
+ snapshot_index: index (starting at 0) of the snapshot to use to analyze the
+ videos. To evaluate the last one, use -1. For example if we have
+ - snapshot-0.pt
+ - snapshot-50.pt
+ - snapshot-100.pt
+ - snapshot-best.pt
+ and we want to evaluate snapshot-50.pt, snapshotindex should be 1. If None,
+ the snapshot index is loaded from the project configuration.
+ detector_snapshot_index: (only for top-down models) index of the detector
+ snapshot to use, used in the same way as ``snapshot_index``
+ dynamic: (state, detection threshold, margin) triplet. If the state is true,
+ then dynamic cropping will be performed. That means that if an object is
+ detected (i.e. any body part > detection threshold), then object boundaries
+ are computed according to the smallest/largest x position and
+ smallest/largest y position of all body parts. This window is expanded by
+ the margin and from then on only the posture within this crop is analyzed
+ (until the object is lost, i.e. < detection threshold). The current position
+ is utilized for updating the crop window for the next frame (this is why the
+ margin is important and should be set large enough given the movement of the
+ animal).
+ modelprefix: directory containing the deeplabcut models to use when evaluating
+ the network. By default, they are assumed to exist in the project folder.
+ batch_size: the batch size to use for inference. Takes the value from the
+ project config as a default.
+ detector_batch_size: the batch size to use for detector inference. Takes the
+ value from the project config as a default.
+ transform: Optional custom transforms to apply to the video
+ overwrite: Overwrite any existing videos
+ use_shelve: By default, data are dumped in a pickle file at the end of the video
+ analysis. Otherwise, data are written to disk on the fly using a "shelf";
+ i.e., a pickle-based, persistent, database-like object by default, resulting
+ in constant memory footprint.
+ robust_nframes: Evaluate a video's number of frames in a robust manner. This
+ option is slower (as the whole video is read frame-by-frame), but does not
+ rely on metadata, hence its robustness against file corruption.
+ auto_track: By default, tracking and stitching are automatically performed,
+ producing the final h5 data file. This is equivalent to the behavior for
+ single-animal projects.
+
+ If ``False``, one must run ``convert_detections2tracklets`` and
+ ``stitch_tracklets`` afterwards, in order to obtain the h5 file.
+ n_tracks: Number of tracks to reconstruct. By default, taken as the number of
+ individuals defined in the config.yaml. Another number can be passed if the
+ number of animals in the video is different from the number of animals the
+ model was trained on.
+ animal_names: If you want the names given to individuals in the labeled data
+ file, you can specify those names as a list here. If given and `n_tracks`
+ is None, `n_tracks` will be set to `len(animal_names)`. If `n_tracks` is not
+ None, then it must be equal to `len(animal_names)`. If it is not given, then
+ `animal_names` will be loaded from the `individuals` in the project
+ `config.yaml` file.
+ identity_only: sub-call for auto_track. If ``True`` and animal identity was
+ learned by the model, assembly and tracking rely exclusively on identity
+ prediction.
+ cropping: List of cropping coordinates as [x1, x2, y1, y2]. Note that the same
+ cropping parameters will then be used for all videos. If different video
+ crops are desired, run ``analyze_videos`` on individual videos with the
+ corresponding cropping coordinates.
+ save_as_df: Cannot be used when `use_shelve` is True. Saves the video
+ predictions (before tracking results) to an H5 file containing a pandas
+ DataFrame. If ``save_as_csv==True`` than the full predictions will also be
+ saved in a CSV file.
+
+ Returns:
+ The scorer used to analyze the videos
+ """
+ # Create the output folder
+ _validate_destfolder(destfolder)
+
+ # Load the project configuration
+ cfg = auxiliaryfunctions.read_config(config)
+ project_path = Path(cfg["project_path"])
+ train_fraction = cfg["TrainingFraction"][trainingsetindex]
+ model_folder = project_path / auxiliaryfunctions.get_model_folder(
+ train_fraction,
+ shuffle,
+ cfg,
+ modelprefix=modelprefix,
+ engine=Engine.PYTORCH,
+ )
+ train_folder = model_folder / "train"
+
+ # Read the inference configuration, load the model
+ model_cfg_path = train_folder / Engine.PYTORCH.pose_cfg_name
+ model_cfg = auxiliaryfunctions.read_plainconfig(model_cfg_path)
+ pose_task = Task(model_cfg["method"])
+
+ pose_cfg_path = model_folder / "test" / "pose_cfg.yaml"
+ pose_cfg = auxiliaryfunctions.read_plainconfig(pose_cfg_path)
+
+ snapshot_index, detector_snapshot_index = utils.parse_snapshot_index_for_analysis(
+ cfg, model_cfg, snapshot_index, detector_snapshot_index,
+ )
+
+ if cropping is None and cfg.get("cropping", False):
+ cropping = cfg["x1"], cfg["x2"], cfg["y1"], cfg["y2"]
+
+ # Get general project parameters
+ multi_animal = cfg["multianimalproject"]
+ bodyparts = model_cfg["metadata"]["bodyparts"]
+ unique_bodyparts = model_cfg["metadata"]["unique_bodyparts"]
+ individuals = model_cfg["metadata"]["individuals"]
+ max_num_animals = len(individuals)
+
+ if device is not None:
+ model_cfg["device"] = device
+
+ if batch_size is None:
+ batch_size = cfg.get("batch_size", 1)
+
+ if not multi_animal:
+ save_as_df = True
+ if use_shelve:
+ print(
+ "The ``use_shelve`` parameter cannot be used for single animal "
+ "projects. Setting ``use_shelve=False``."
+ )
+ use_shelve = False
+
+ dynamic = DynamicCropper.build(*dynamic)
+ if pose_task != Task.BOTTOM_UP and dynamic is not None:
+ print(
+ "Turning off dynamic cropping. It should only be used for bottom-up "
+ f"pose estimation models, but you are using a top-down model."
+ )
+ dynamic = None
+
+ snapshot = utils.get_model_snapshots(snapshot_index, train_folder, pose_task)[0]
+ print(f"Analyzing videos with {snapshot.path}")
+ pose_runner = utils.get_pose_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=snapshot.path,
+ max_individuals=max_num_animals,
+ batch_size=batch_size,
+ transform=transform,
+ dynamic=dynamic,
+ )
+ detector_runner = None
+
+ detector_path, detector_snapshot = None, None
+ if pose_task == Task.TOP_DOWN:
+ if detector_snapshot_index is None:
+ raise ValueError(
+ "Cannot run videos analysis for top-down models without a detector "
+ "snapshot! Please specify your desired detector_snapshotindex in your "
+ "project's configuration file."
+ )
+
+ if detector_batch_size is None:
+ detector_batch_size = cfg.get("detector_batch_size", 1)
+
+ detector_snapshot = utils.get_model_snapshots(
+ detector_snapshot_index, train_folder, Task.DETECT
+ )[0]
+ print(f" -> Using detector {detector_snapshot.path}")
+ detector_runner = utils.get_detector_inference_runner(
+ model_config=model_cfg,
+ snapshot_path=detector_snapshot.path,
+ max_individuals=max_num_animals,
+ batch_size=detector_batch_size,
+ )
+
+ dlc_scorer = utils.get_scorer_name(
+ cfg,
+ shuffle,
+ train_fraction,
+ snapshot_uid=utils.get_scorer_uid(snapshot, detector_snapshot),
+ modelprefix=modelprefix,
+ )
+
+ # Reading video and init variables
+ videos = utils.list_videos_in_folder(videos, videotype, shuffle=in_random_order)
+ for video in videos:
+ if destfolder is None:
+ output_path = video.parent
+ else:
+ output_path = Path(destfolder)
+
+ output_prefix = video.stem + dlc_scorer
+ output_pkl = output_path / f"{output_prefix}_full.pickle"
+
+ video_iterator = VideoIterator(video, cropping=cropping)
+
+ shelf_writer = None
+ if use_shelve:
+ shelf_writer = shelving.ShelfWriter(
+ pose_cfg=pose_cfg,
+ filepath=output_pkl,
+ num_frames=video_iterator.get_n_frames(robust=robust_nframes),
+ )
+
+ if not overwrite and output_pkl.exists():
+ print(f"Video {video} already analyzed at {output_pkl}!")
+ else:
+ runtime = [time.time()]
+ predictions = video_inference(
+ video=video_iterator,
+ pose_runner=pose_runner,
+ detector_runner=detector_runner,
+ shelf_writer=shelf_writer,
+ robust_nframes=robust_nframes,
+ )
+ runtime.append(time.time())
+ metadata = _generate_metadata(
+ cfg=cfg,
+ pytorch_config=model_cfg,
+ dlc_scorer=dlc_scorer,
+ train_fraction=train_fraction,
+ batch_size=batch_size,
+ cropping=cropping,
+ runtime=(runtime[0], runtime[1]),
+ video=video_iterator,
+ robust_nframes=robust_nframes,
+ )
+
+ with open(output_path / f"{output_prefix}_meta.pickle", "wb") as f:
+ pickle.dump(metadata, f, pickle.HIGHEST_PROTOCOL)
+
+ if use_shelve and save_as_df:
+ print("Can't ``save_as_df`` as ``use_shelve=True``. Skipping.")
+
+ if not use_shelve:
+ output_data = _generate_output_data(pose_cfg, predictions)
+ with open(output_pkl, "wb") as f:
+ pickle.dump(output_data, f, pickle.HIGHEST_PROTOCOL)
+
+ if save_as_df:
+ create_df_from_prediction(
+ predictions=predictions,
+ multi_animal=multi_animal,
+ model_cfg=model_cfg,
+ dlc_scorer=dlc_scorer,
+ output_path=output_path,
+ output_prefix=output_prefix,
+ save_as_csv=save_as_csv,
+ )
+
+ if multi_animal:
+ _generate_assemblies_file(
+ full_data_path=output_pkl,
+ output_path=output_path / f"{output_prefix}_assemblies.pickle",
+ num_bodyparts=len(bodyparts),
+ num_unique_bodyparts=len(unique_bodyparts),
+ )
+
+ if auto_track:
+ convert_detections2tracklets(
+ config=config,
+ videos=str(video),
+ videotype=videotype,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ overwrite=False,
+ identity_only=identity_only,
+ destfolder=str(output_path),
+ )
+ stitch_tracklets(
+ config,
+ [str(video)],
+ videotype,
+ shuffle,
+ trainingsetindex,
+ n_tracks=n_tracks,
+ animal_names=animal_names,
+ destfolder=str(output_path),
+ save_as_csv=save_as_csv,
+ )
+
+ print(
+ "The videos are analyzed. Now your research can truly start!\n"
+ "You can create labeled videos with 'create_labeled_video'.\n"
+ "If the tracking is not satisfactory for some videos, consider expanding the "
+ "training set. You can use the function 'extract_outlier_frames' to extract a "
+ "few representative outlier frames.\n"
+ )
+
+ return dlc_scorer
+
+
+def create_df_from_prediction(
+ predictions: list[dict[str, np.ndarray]],
+ dlc_scorer: str,
+ multi_animal: bool,
+ model_cfg: dict,
+ output_path: str | Path,
+ output_prefix: str | Path,
+ save_as_csv: bool = False,
+) -> pd.DataFrame:
+ pred_bodyparts = np.stack([p["bodyparts"][..., :3] for p in predictions])
+ pred_unique_bodyparts = None
+ if len(predictions) > 0 and "unique_bodyparts" in predictions[0]:
+ pred_unique_bodyparts = np.stack([p["unique_bodyparts"] for p in predictions])
+
+ output_h5 = Path(output_path) / f"{output_prefix}.h5"
+ output_pkl = Path(output_path) / f"{output_prefix}_full.pickle"
+
+ bodyparts = model_cfg["metadata"]["bodyparts"]
+ unique_bodyparts = model_cfg["metadata"]["unique_bodyparts"]
+ individuals = model_cfg["metadata"]["individuals"]
+ n_individuals = len(individuals)
+
+ print(f"Saving results in {output_h5} and {output_pkl}")
+ coords = ["x", "y", "likelihood"]
+ cols = [[dlc_scorer], bodyparts, coords]
+ cols_names = ["scorer", "bodyparts", "coords"]
+
+ if multi_animal:
+ cols.insert(1, individuals)
+ cols_names.insert(1, "individuals")
+
+ results_df_index = pd.MultiIndex.from_product(cols, names=cols_names)
+ pred_bodyparts = pred_bodyparts[:, :n_individuals]
+ df = pd.DataFrame(
+ pred_bodyparts.reshape((len(pred_bodyparts), -1)),
+ columns=results_df_index,
+ index=range(len(pred_bodyparts)),
+ )
+ if pred_unique_bodyparts is not None:
+ unique_columns = [dlc_scorer], ["single"], unique_bodyparts, coords
+ df_u = pd.DataFrame(
+ pred_unique_bodyparts.reshape((len(pred_unique_bodyparts), -1)),
+ columns=pd.MultiIndex.from_product(unique_columns, names=cols_names),
+ index=range(len(pred_unique_bodyparts)),
+ )
+ df = df.join(df_u, how="outer")
+
+ df.to_hdf(output_h5, key="df_with_missing", format="table", mode="w")
+ if save_as_csv:
+ df.to_csv(output_h5.with_suffix(".csv"))
+ return df
+
+
+def _generate_assemblies_file(
+ full_data_path: Path,
+ output_path: Path,
+ num_bodyparts: int,
+ num_unique_bodyparts: int,
+) -> None:
+ """Generates the assemblies file from predictions"""
+ if full_data_path.exists():
+ with open(full_data_path, "rb") as f:
+ data = pickle.load(f)
+
+ else:
+ data = shelving.ShelfReader(full_data_path)
+ data.open()
+
+ num_frames = data["metadata"]["nframes"]
+ str_width = data["metadata"].get("key_str_width")
+ if str_width is None:
+ keys = [k for k in data.keys() if k != "metadata"]
+ str_width = len(keys[0]) - len("frame")
+
+ assemblies = dict(single=dict())
+ for frame_index in range(num_frames):
+ frame_key = "frame" + str(frame_index).zfill(str_width)
+ predictions = data[frame_key]
+
+ keypoint_preds = predictions["coordinates"][0]
+ keypoint_scores = predictions["confidence"]
+
+ bpts = np.stack(keypoint_preds[:num_bodyparts])
+ scores = np.stack(keypoint_scores[:num_bodyparts])
+ preds = np.concatenate([bpts, scores], axis=-1)
+
+ keypoint_id_scores = predictions.get("identity")
+ if keypoint_id_scores is not None:
+ keypoint_id_scores = np.stack(keypoint_id_scores[:num_bodyparts])
+ keypoint_pred_ids = np.argmax(keypoint_id_scores, axis=2)
+ keypoint_pred_ids = np.expand_dims(keypoint_pred_ids, axis=-1)
+ else:
+ num_bpts, num_preds = preds.shape[:2]
+ keypoint_pred_ids = -np.ones((num_bpts, num_preds, 1))
+
+ # reshape to (num_preds, num_bpts, 4)
+ preds = np.concatenate([preds, keypoint_pred_ids], axis=-1)
+ preds = preds.transpose((1, 0, 2))
+
+ # remove all-missing predictions
+ mask = ~np.all(preds < 0, axis=(1, 2))
+ preds = preds[mask]
+
+ assemblies[frame_index] = preds
+
+ if num_unique_bodyparts > 0:
+ unique_bpts = np.stack(keypoint_preds[num_bodyparts:])
+ unique_scores = np.stack(keypoint_scores[num_bodyparts:])
+ unique_preds = np.concatenate([unique_bpts, unique_scores], axis=-1)
+ unique_preds = unique_preds.transpose((1, 0, 2))
+ assemblies["single"][frame_index] = unique_preds[0] # single prediction
+
+ with open(output_path, "wb") as file:
+ pickle.dump(assemblies, file, pickle.HIGHEST_PROTOCOL)
+
+ if isinstance(data, shelving.ShelfReader):
+ data.close()
+
+
+def _validate_destfolder(destfolder: str | None) -> None:
+ """Checks that the destfolder for video analysis is valid"""
+ if destfolder is not None and destfolder != "":
+ output_folder = Path(destfolder)
+ if not output_folder.exists():
+ print(f"Creating the output folder {output_folder}")
+ output_folder.mkdir(parents=True)
+
+ assert Path(
+ output_folder
+ ).is_dir(), f"Output folder must be a directory: you passed '{output_folder}'"
+
+
+def _generate_metadata(
+ cfg: dict,
+ pytorch_config: dict,
+ dlc_scorer: str,
+ train_fraction: int,
+ batch_size: int,
+ cropping: list[int] | None,
+ runtime: tuple[float, float],
+ video: VideoIterator,
+ robust_nframes: bool = False,
+) -> dict:
+ w, h = video.dimensions
+ if cropping is None:
+ cropping_parameters = [0, w, 0, h]
+ else:
+ if not len(cropping) == 4:
+ raise ValueError(
+ "The cropping parameters should be exactly 4 values: [x_min, x_max, "
+ f"y_min, y_max]. Found {cropping}"
+ )
+ cropping_parameters = cropping
+
+ metadata = {
+ "start": runtime[0],
+ "stop": runtime[1],
+ "run_duration": runtime[1] - runtime[0],
+ "Scorer": dlc_scorer,
+ "pytorch-config": pytorch_config,
+ "fps": video.fps,
+ "batch_size": batch_size,
+ "frame_dimensions": (w, h),
+ "nframes": video.get_n_frames(robust=robust_nframes),
+ "iteration (active-learning)": cfg["iteration"],
+ "training set fraction": train_fraction,
+ "cropping": cropping is not None,
+ "cropping_parameters": cropping_parameters,
+ "individuals": pytorch_config["metadata"]["individuals"],
+ "bodyparts": pytorch_config["metadata"]["bodyparts"],
+ "unique_bodyparts": pytorch_config["metadata"]["unique_bodyparts"],
+ }
+ return {"data": metadata}
+
+
+def _generate_output_data(
+ pose_config: dict,
+ predictions: list[dict[str, np.ndarray]],
+) -> dict:
+ str_width = int(np.ceil(np.log10(len(predictions))))
+ output = {
+ "metadata": {
+ "nms radius": pose_config.get("nmsradius"),
+ "minimal confidence": pose_config.get("minconfidence"),
+ "sigma": pose_config.get("sigma", 1),
+ "PAFgraph": pose_config.get("partaffinityfield_graph"),
+ "PAFinds": pose_config.get(
+ "paf_best",
+ np.arange(len(pose_config.get("partaffinityfield_graph", []))),
+ ),
+ "all_joints": [[i] for i in range(len(pose_config["all_joints"]))],
+ "all_joints_names": [
+ pose_config["all_joints_names"][i]
+ for i in range(len(pose_config["all_joints"]))
+ ],
+ "nframes": len(predictions),
+ "key_str_width": str_width,
+ }
+ }
+
+ for frame_num, frame_predictions in enumerate(predictions):
+ key = "frame" + str(frame_num).zfill(str_width)
+ # shape (num_assemblies, num_bpts, 3)
+ bodyparts = frame_predictions["bodyparts"]
+ # shape (num_bpts, num_assemblies, 3)
+ bodyparts = bodyparts.transpose((1, 0, 2))
+ coordinates = [bpt[:, :2] for bpt in bodyparts]
+ scores = [bpt[:, 2:3] for bpt in bodyparts]
+
+ # full pickle has bodyparts and unique bodyparts in same array
+ num_unique = 0
+ if "unique_bodyparts" in frame_predictions:
+ unique_bpts = frame_predictions["unique_bodyparts"].transpose((1, 0, 2))
+ coordinates += [bpt[:, :2] for bpt in unique_bpts]
+ scores += [bpt[:, 2:] for bpt in unique_bpts]
+ num_unique = len(unique_bpts)
+
+ output[key] = {
+ "coordinates": (coordinates,),
+ "confidence": scores,
+ "costs": None,
+ }
+
+ if "bboxes" in frame_predictions:
+ output[key]["bboxes"] = frame_predictions["bboxes"]
+ output[key]["bbox_scores"] = frame_predictions["bbox_scores"]
+
+ if "identity_scores" in frame_predictions:
+ # Reshape id scores from (num_assemblies, num_bpts, num_individuals)
+ # to the original DLC full pickle format: (num_bpts, num_assem, num_ind)
+ id_scores = frame_predictions["identity_scores"]
+ id_scores = id_scores.transpose((1, 0, 2))
+ output[key]["identity"] = [bpt_id_scores for bpt_id_scores in id_scores]
+
+ if num_unique > 0:
+ # needed for create_video_with_all_detections to display unique bpts
+ num_assem, num_ind = id_scores.shape[1:]
+ output[key]["identity"] += [
+ -1 * np.ones((num_assem, num_ind)) for i in range(num_unique)
+ ]
+
+ return output
diff --git a/deeplabcut/pose_estimation_pytorch/apis/visualization.py b/deeplabcut/pose_estimation_pytorch/apis/visualization.py
new file mode 100644
index 0000000000..f4912f55db
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/apis/visualization.py
@@ -0,0 +1,695 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Methods to help with visualization of model outputs"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import cv2
+import matplotlib.collections as collections
+import matplotlib.colors as colors
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from tqdm import tqdm
+
+import deeplabcut.core.visualization as visualization
+import deeplabcut.pose_estimation_pytorch.apis.utils as utils
+import deeplabcut.pose_estimation_pytorch.data as data
+import deeplabcut.pose_estimation_pytorch.data.preprocessor as preprocessor
+import deeplabcut.pose_estimation_pytorch.models as models
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.core.engine import Engine
+from deeplabcut.pose_estimation_pytorch.task import Task
+from deeplabcut.utils import auxiliaryfunctions
+
+
+def create_labeled_images(
+ predictions: dict[str, dict[str, np.ndarray | np.ndarray]],
+ out_folder: str | Path,
+ pcutoff: float = 0.6,
+ bboxes_pcutoff: float = 0.6,
+ mode: str = "bodypart",
+ cmap: str | colors.Colormap = "rainbow",
+ dot_size: int = 12,
+ alpha_value: float = 0.7,
+ skeleton: list[tuple[int, int]] | None = None,
+ skeleton_color: str = "k",
+ close_figure_after_save: bool = True,
+):
+ """Plots model predictions on images.
+
+ Args:
+ predictions: The predictions to plot. A dictionary mapping image paths to
+ the predictions made by the model on that image. The predictions should
+ contain a "bodyparts" key, mapping to an array of shape (max_individuals,
+ num_bodyparts, 3) containing predicted bodyparts. If there are any unique
+ bodyparts predicted, then it should also contain a "unique_bodyparts" key,
+ mapping to an array of shape (1, num_bodyparts, 3) containing the predicted
+ unique bodyparts.
+ out_folder: The folder where model predictions should be saved.
+ pcutoff: The p-cutoff score above which predicted bodyparts are displayed with
+ a "⋅" marker, and below which they are displayed with a "X" marker.
+ bboxes_pcutoff: The bounding box cutoff score, below which predicted bounding
+ boxes are shown with a dashed line.
+ mode: One of "bodypart", "individual". Whether to color predictions by
+ bodypart or individual.
+ cmap: The colormap to use to plot predictions.
+ dot_size: The size of the bodypart prediction markers.
+ alpha_value: The transparency value of the bodypart prediction markers.
+ skeleton: If skeletons should be plotted, the list of bodyparts that constitute
+ the skeletons.
+ skeleton_color: The color with which to plot the skeleton, if one is given.
+ close_figure_after_save: Whether to close figures after saving the labeled
+ images to disk.
+ """
+ out_folder = Path(out_folder)
+ out_folder.mkdir(exist_ok=True)
+
+ color_by_individual = mode == "individual"
+ if isinstance(cmap, str):
+ cmap = plt.cm.get_cmap(cmap)
+
+ for image_path, image_predictions in predictions.items():
+ # Load frame
+ frame = Image.open(str(image_path))
+
+ # get pose predictions
+ pred = image_predictions["bodyparts"]
+ total_idv, total_bodyparts = pred.shape[:2]
+ unique_pred = None
+ if "unique_bodyparts" in image_predictions:
+ unique_pred = image_predictions["unique_bodyparts"][0]
+ total_idv += 1
+ total_bodyparts += len(unique_pred)
+
+ # create plot
+ fig, ax = plt.subplots()
+ ax.imshow(frame)
+
+ # plot bodyparts
+ for idx, pose in enumerate(pred):
+ xy, scores = pose[:, :2], pose[:, 2]
+ mask = scores > pcutoff
+ if np.sum(pose) < 0 or np.sum(mask) <= 0:
+ continue
+
+ bones = []
+ if skeleton is not None:
+ for idx_1, idx_2 in skeleton:
+ if scores[idx_1] > pcutoff and scores[idx_2] > pcutoff:
+ bones.append(xy[[idx_1, idx_2]])
+
+ kwargs = dict(s=dot_size)
+ if color_by_individual:
+ kwargs["c"] = cmap(idx / total_idv)
+ else:
+ c = np.linspace(0, 1, total_bodyparts)[:len(pose)][mask]
+ kwargs["c"] = c
+ kwargs["cmap"] = cmap
+
+ xy = xy[mask]
+ ax.scatter(xy[:, 0], xy[:, 1], **kwargs)
+ if len(bones) > 0:
+ ax.add_collection(
+ collections.LineCollection(
+ bones, colors=skeleton_color, alpha=alpha_value
+ )
+ )
+
+ # plot unique bodyparts
+ if unique_pred is not None:
+ xy, scores = unique_pred[:, :2], unique_pred[:, 2]
+ mask = scores > pcutoff
+ if np.sum(mask) <= 0:
+ continue
+
+ kwargs = dict(s=dot_size)
+ if color_by_individual:
+ kwargs["c"] = cmap(1)
+ else:
+ c = np.linspace(0, 1, total_bodyparts)
+ kwargs["c"] = c[-len(unique_pred):][mask]
+ kwargs["cmap"] = cmap
+
+ xy = xy[mask]
+ ax.scatter(xy[:, 0], xy[:, 1], **kwargs)
+
+ # plot bounding boxes
+ if "bboxes" in image_predictions:
+ bboxes = image_predictions["bboxes"]
+ bbox_scores = image_predictions["bbox_scores"]
+ for idx, (bbox, score) in enumerate(zip(bboxes, bbox_scores)):
+ if score <= bboxes_pcutoff:
+ continue
+
+ xmin, ymin, w, h = bbox
+ rect = plt.Rectangle(
+ (xmin, ymin), w, h, fill=False, edgecolor="green", linewidth=2
+ )
+ ax.add_patch(rect)
+
+ # save predictions
+ output_path = out_folder / f"predictions_{Path(image_path).stem}.png"
+ fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
+ fig.savefig(output_path)
+
+ if close_figure_after_save:
+ plt.close(fig)
+
+ if close_figure_after_save:
+ plt.close()
+
+
+@torch.no_grad()
+def extract_model_outputs(
+ images: list[str] | list[Path],
+ model: models.PoseModel,
+ pre_processor: preprocessor.Preprocessor,
+ device: str = "auto",
+ context: list[dict[str, np.ndarray]] | None = None,
+) -> list[dict[str, np.ndarray]]:
+ """Obtains the outputs for a model for a list of images
+
+ Args:
+ images: List of image paths for which to get model outputs.
+ model: The model for which to get model outputs.
+ pre_processor: The pre-processor used to prepare the images before giving them
+ to the model.
+ device: The device on which to run inference.
+ context: The context for each image to give to the pre-processor. For top-down
+ models, this context should contain the bounding boxes to use for each
+ image. This should be in a format:
+ [
+ {"bboxes": array of shape (num_bboxes, 4)}, # image 1 bboxes,
+ {"bboxes": array of shape (num_bboxes, 4)}, # image 2 bboxes,
+ ...,
+ {"bboxes": array of shape (num_bboxes, 4)}, # image N bboxes,
+ ]
+
+ Returns:
+ A list containing a dict for each input image, in the format:
+ {
+ inputs: a numpy array containing the inputs given to the model for the image
+ context: the context given alongside the image
+ outputs: a dict containing the model outputs
+ }
+ """
+ if context is not None and len(context) != len(images):
+ raise ValueError(
+ "When passing context along with the images (e.g. bounding boxes for "
+ "top-down models), there should be the same number of elements in the "
+ f"context as the number of images. Received {len(images)} images but "
+ f"{len(context)} contexts."
+ )
+
+ model = model.to(device)
+ model = model.eval()
+
+ model_data = []
+ for idx, image in enumerate(images):
+ image_context = {}
+ if context is not None:
+ image_context = context[idx]
+
+ inputs, image_context = pre_processor(image, image_context)
+ output = model(inputs.to(device))
+
+ for head, head_cfg in model.cfg["heads"].items():
+ if (
+ head_cfg["predictor"].get("apply_sigmoid", False)
+ or head_cfg["predictor"]["type"] == "PartAffinityFieldPredictor"
+ ):
+ if "heatmap" in output[head]:
+ output[head]["heatmap"] = F.sigmoid(output[head]["heatmap"])
+
+ output = {
+ head: {name: output.cpu().numpy() for name, output in head_outputs.items()}
+ for head, head_outputs in output.items()
+ }
+ model_data.append(
+ dict(inputs=inputs.cpu().numpy(), context=context, outputs=output)
+ )
+
+ return model_data
+
+
+def extract_maps(
+ config,
+ shuffle: int = 0,
+ trainingsetindex: int | str = 0,
+ device: str | None = None,
+ rescale: bool = False,
+ indices: list[int] | None = None,
+ extract_paf: bool = True,
+ modelprefix: str | None = "",
+ snapshot_index: int | str | None = None,
+ detector_snapshot_index: int | str | None = None,
+) -> dict:
+ """
+ Extracts the different maps output by DeepLabCut models, such as scoremaps, location
+ refinement fields and part-affinity fields.
+
+ Args:
+ config: Full path of the config.yaml file as a string.
+ shuffle: Index of the shuffle for which to extract maps
+ trainingset_index: Integer specifying which TrainingsetFraction to use. This
+ variable can also be set to "all".
+ rescale: Evaluate the model at the 'global_scale' variable (as set in the
+ test/pose_config.yaml file for a particular project). Every image will be
+ resized according to that scale and prediction will be compared to the
+ resized ground truth. The error will be reported in pixels at rescaled to
+ the *original* size. Example:
+ For a [200, 200] pixel image evaluated at ``global_scale=0.5``,
+ predictions are calculated on [100, 100] pixel images, compared to
+ ``0.5*ground truth`` and this error is then multiplied by 2!. The
+ evaluation images are also shown for the original size!
+ indices: Optionally, you can only obtain maps for a subset of images in your
+ dataset. The indices given here are the indices of the images for which
+ maps will be extracted.
+ modelprefix: Directory containing the deeplabcut models to use when evaluating
+ the network. By default, the models are assumed to exist in the project
+ folder.
+ snapshot_index: Index (starting at 0) of the snapshot we want to extract maps
+ with. To evaluate the last one, use -1. To extract maps for all snapshots,
+ use "all".
+ detector_snapshot_index: Only for TD models. If defined, uses the detector with
+ the given index for pose estimation. To extract maps for all detector
+ snapshots, use "all".
+
+ Returns:
+ a dict indexed by (trainingset_fraction, snapshot_index, image_index). For each
+ key, the item contains a tuple of:
+ (img, scmap, locref, paf, bpt_names, paf_graph, img_name, is_train)
+
+ Examples
+ --------
+ If you want to extract the data for image 0 and 103 (of the training set) for
+ model trained with shuffle 0.
+
+ >>> deeplabcut.extract_maps(config, 0, indices=[0, 103])
+ """
+ cfg = read_config_as_dict(config)
+
+ trainset_indices = [trainingsetindex]
+ if trainingsetindex == "all":
+ trainset_indices = [i for i in range(len(cfg["TrainingFraction"]))]
+ if snapshot_index is None:
+ snapshot_index = cfg["snapshotindex"]
+ if detector_snapshot_index is None:
+ detector_snapshot_index = cfg["detector_snapshotindex"]
+
+ extracted_maps = {}
+ for trainset_index in trainset_indices:
+ loader = data.DLCLoader(
+ config=config,
+ shuffle=shuffle,
+ trainset_index=trainset_index,
+ modelprefix=modelprefix,
+ )
+ extracted_maps[loader.train_fraction] = {}
+
+ # (img, scmap, locref, paf, bpt_names, paf_graph, img_name, is_train)
+ metadata = loader.model_cfg["metadata"]
+ bpt_names = metadata["bodyparts"] + metadata["unique_bodyparts"]
+ paf_graph = []
+ bpt_head_cfg = loader.model_cfg["model"]["heads"]["bodypart"]
+ if bpt_head_cfg["type"] == "DLCRNetHead":
+ paf_graph = bpt_head_cfg.get("predictor", {}).get("graph")
+ paf_indices = bpt_head_cfg.get("predictor", {}).get("edges_to_keep")
+ if paf_indices is not None:
+ paf_graph = [paf_graph[i] for i in paf_indices]
+
+ if device is not None:
+ loader.model_cfg["device"] = device
+ loader.model_cfg["device"] = utils.resolve_device(loader.model_cfg)
+ device = loader.model_cfg["device"]
+
+ if snapshot_index is None:
+ snapshot_index = -1
+ snapshots = utils.get_model_snapshots(
+ snapshot_index, loader.model_folder, loader.pose_task
+ )
+
+ image_paths = loader.df.index
+ if indices is not None:
+ image_paths = [image_paths[idx] for idx in indices]
+ if len(image_paths) > 0 and isinstance(image_paths[0], tuple):
+ image_paths = [Path(*img_path) for img_path in image_paths]
+
+ image_paths = [
+ (loader.project_path / img_path).resolve() for img_path in image_paths
+ ]
+
+ context = _get_context(image_paths, loader, detector_snapshot_index, device)
+ train_idx = set(loader.split["train"])
+ for snapshot in snapshots:
+ snapshot_id = snapshot.path.stem
+ extracted_maps[loader.train_fraction][snapshot_id] = {}
+ runner = utils.get_pose_inference_runner(
+ model_config=loader.model_cfg,
+ snapshot_path=snapshot.path,
+ )
+ results = extract_model_outputs(
+ image_paths,
+ runner.model,
+ runner.preprocessor,
+ runner.device,
+ context=context,
+ )
+ for idx, result in enumerate(results):
+ image_idx = idx
+ if indices is not None:
+ image_idx = indices[idx]
+
+ # key can be just image_idx, or (image_idx, bbox_idx) for TD models
+ keys, images, outputs = _collect_model_outputs(
+ loader.pose_task, result, image_idx
+ )
+ for key, image, output in zip(keys, images, outputs):
+ parsed = _parse_model_outputs(
+ image,
+ output,
+ strides={
+ k: runner.model.get_stride(k)
+ for k in runner.model.heads.keys()
+ },
+ denormalize_image=True,
+ )
+ img_name = image_paths[idx].stem
+ if isinstance(key, tuple):
+ bbox_id = key[1]
+ img_name += f"_bbox{bbox_id:03d}"
+
+ is_train = image_idx in train_idx
+ extracted_maps[loader.train_fraction][snapshot_id][key] = (
+ *parsed,
+ None,
+ bpt_names,
+ paf_graph,
+ img_name,
+ is_train,
+ )
+
+ # img, scmap, locref, paf, peaks, bpt_names, paf_graph, img_name, is_train
+ return extracted_maps
+
+
+def extract_save_all_maps(
+ config: str | Path,
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ comparison_bodyparts: str | list[str] = "all",
+ extract_paf: bool = True,
+ all_paf_in_one: bool = True,
+ device: str | None = None,
+ rescale: bool = False,
+ indices: list[int] | None = None,
+ modelprefix: str | None = "",
+ snapshot_index: int | str | None = None,
+ detector_snapshot_index: int | str | None = None,
+ dest_folder: str | Path | None = None,
+):
+ """
+ Extracts the scoremap, location refinement field and part affinity field prediction
+ of the model. The maps will be rescaled to the size of the input image and stored
+ in the corresponding model folder in /evaluation-results-pytorch.
+
+ Args:
+ config: Full path of the config.yaml file as a string.
+ shuffle: Index of the shuffle for which to extract maps
+ trainingset_index: Integer specifying which TrainingsetFraction to use. This
+ variable can also be set to "all".
+ comparison_bodyparts: The average error will be computed for those body parts
+ only (Has to be a subset of the body parts).
+ extract_paf: Extract part affinity fields by default. Note that turning it off
+ will make the function much faster.
+ all_paf_in_one: By default, all part affinity fields are displayed on a single
+ frame. If false, individual fields are shown on separate frames.
+ indices: Optionally, you can only obtain maps for a subset of images in your
+ dataset. The indices given here are the indices of the images for which
+ maps will be extracted.
+ modelprefix: Directory containing the deeplabcut models to use when evaluating
+ the network. By default, the models are assumed to exist in the project
+ folder.
+ snapshot_index: Index (starting at 0) of the snapshot we want to extract maps
+ with. To evaluate the last one, use -1. To extract maps for all snapshots,
+ use "all".
+ detector_snapshot_index: Only for TD models. If defined, uses the detector with
+ the given index for pose estimation. To extract maps for all detector
+ snapshots, use "all".
+
+ Examples
+ --------
+ Calculated maps for images 0, 1 and 33.
+ >>> deeplabcut.extract_save_all_maps(
+ >>> "/analysis/project/reaching-task/config.yaml",
+ >>> shuffle=1,
+ >>> indices=[0, 1, 33]
+ >>> )
+
+ """
+ cfg = read_config_as_dict(config)
+ maps = extract_maps(
+ config,
+ shuffle=shuffle,
+ trainingsetindex=trainingsetindex,
+ device=device,
+ rescale=rescale,
+ indices=indices,
+ snapshot_index=snapshot_index,
+ detector_snapshot_index=detector_snapshot_index,
+ modelprefix=modelprefix,
+ )
+ bpts_to_plot = auxiliaryfunctions.intersection_of_body_parts_and_ones_given_by_user(
+ cfg, comparison_bodyparts
+ )
+
+ print("Saving plots...")
+ for frac, values in maps.items():
+ dest_folder = _get_maps_folder(cfg, frac, shuffle, modelprefix, dest_folder)
+ dest_folder.mkdir(exist_ok=True)
+ for snap, maps in values.items():
+ for image_idx, image_maps in tqdm(maps.items()):
+ (
+ image,
+ scmap,
+ locref,
+ paf,
+ peaks,
+ bpt_names,
+ paf_graph,
+ image_path,
+ training_image,
+ ) = image_maps
+
+ if not extract_paf:
+ paf = []
+
+ label = "train" if training_image else "test"
+ img_w, img_h = image.shape[1], image.shape[0]
+ scmap = _prepare_maps_for_plotting(scmap, (img_w, img_h))
+ if scmap is None:
+ raise ValueError("Cannot plot heatmaps - none output by the model")
+
+ locref = _prepare_maps_for_plotting(locref, (img_w, img_h))
+ if locref is not None:
+ locref = locref.reshape((img_h, img_w, -1, 2))
+ paf = _prepare_maps_for_plotting(paf, (img_w, img_h))
+
+ visualization.generate_model_output_plots(
+ output_folder=dest_folder,
+ image_name=Path(image_path).stem,
+ bodypart_names=bpt_names,
+ bodyparts_to_plot=bpts_to_plot,
+ image=image,
+ scmap=scmap,
+ locref=locref,
+ paf=paf,
+ paf_graph=paf_graph,
+ paf_all_in_one=all_paf_in_one,
+ paf_colormap=cfg["colormap"],
+ output_suffix=f"{label}_{shuffle}_{frac}_{snap}",
+ )
+
+
+def _get_context(
+ image_paths: list[Path],
+ loader: data.Loader,
+ detector_snapshot_index: int | str | None,
+ device: str,
+) -> list[dict] | None:
+ """Gets the context for top-down pose estimation models"""
+ if loader.pose_task != Task.TOP_DOWN:
+ return None
+
+ det_snapshots = []
+ if detector_snapshot_index is not None:
+ det_snapshots = utils.get_model_snapshots(
+ detector_snapshot_index, loader.model_folder, Task.DETECT
+ )
+
+ if detector_snapshot_index is None or len(det_snapshots) == 0:
+ if detector_snapshot_index is None:
+ print("No ``detector_snapshot_index`` given.")
+ else:
+ print(f"No detector snapshots found in {loader.model_folder}")
+ print("Using GT bboxes to extract maps for this top-down model")
+
+ bboxes_train = loader.ground_truth_bboxes(mode="train")
+ bboxes_test = loader.ground_truth_bboxes(mode="test")
+ bboxes = {**bboxes_train, **bboxes_test}
+ return [
+ dict(bboxes=bboxes[str(img_path)]["bboxes"]) for img_path in image_paths
+ ]
+
+ detector_runner = utils.get_detector_inference_runner(
+ model_config=loader.model_cfg,
+ snapshot_path=det_snapshots[-1].path,
+ device=device,
+ )
+ return detector_runner.inference(image_paths)
+
+
+def _collect_model_outputs(
+ task: Task,
+ result: dict,
+ image_idx: int,
+) -> tuple[list, list, list]:
+ """Collects the model outputs into data that can be processed.
+
+ Args:
+ task: Whether the model is a bottom-up or top-down model.
+ result: A result output by ``extract_model_outputs``.
+ image_idx: The index of the image
+
+ Returns: keys, images, outputs
+ keys: The key for each image to plot.
+ images: The images to plot for this input image (a single image for bottom-up
+ models, and the number of bounding boxes for top-down models).
+ outputs: The model outputs for each image.
+ """
+ if task == Task.TOP_DOWN:
+ keys, images, outputs = [], [], []
+
+ # parse each input individually
+ num_bboxes = len(result["inputs"])
+ for bbox_idx in range(num_bboxes):
+ keys.append((image_idx, bbox_idx))
+ images.append(result["inputs"][bbox_idx])
+ outputs.append(
+ {
+ head: {k: v[bbox_idx] for k, v in head_outputs.items()}
+ for head, head_outputs in result["outputs"].items()
+ }
+ )
+ return keys, images, outputs
+
+ # remove batch dimension
+ return (
+ [image_idx],
+ [result["inputs"][0]],
+ [
+ {
+ head: {k: v[0] for k, v in head_outputs.items()}
+ for head, head_outputs in result["outputs"].items()
+ }
+ ],
+ )
+
+
+def _parse_model_outputs(
+ image: np.ndarray,
+ outputs: dict[str, dict[str, np.ndarray]],
+ strides: dict[str, int],
+ denormalize_image: bool = True,
+) -> tuple[np.ndarray, list[np.ndarray], list[np.ndarray], list[np.ndarray]]:
+ """Parses the model outputs into a format that can easily be plotted.
+
+ Args:
+ image: The image used to obtain the outputs.
+ outputs: The model outputs.
+ strides: The total stride for each model head.
+ denormalize_image: Whether the image was normalized and should be de-normalized.
+
+ Returns: (img, scmap, locref, paf)
+ img: The (de-normalized) image used as input.
+ scmap: The score maps output by the model.
+ locref: The locref fields output by the model.
+ paf: The part-affinity fields output by the model.
+ """
+ image = image.transpose((1, 2, 0))
+ if denormalize_image:
+ image = image * np.array([0.229, 0.224, 0.225])
+ image = image + np.array([0.485, 0.456, 0.406])
+ image = np.clip(image, 0, 1)
+
+ heatmaps = [h for h in outputs["bodypart"].get("heatmap", [])]
+ locrefs = [m * strides["bodypart"] for m in outputs["bodypart"].get("locref", [])]
+ paf = [p for p in outputs["bodypart"].get("paf", [])]
+
+ if "unique_bodypart" in outputs:
+ heatmaps += [h for h in outputs["unique_bodypart"].get("heatmap", [])]
+ locrefs += [
+ strides["unique_bodypart"] * m
+ for m in outputs["unique_bodypart"].get("locref", [])
+ ]
+
+ return image, heatmaps, locrefs, paf
+
+
+def _prepare_maps_for_plotting(
+ maps: list[np.ndarray], image_size: tuple[int, int]
+) -> np.ndarray | None:
+ """Resizes all maps to the image size and concatenates them into a single array.
+
+ Args:
+ maps: The maps that will be shown on the image.
+ image_size: The (width, height) of the input image.
+
+ Returns:
+ The resized maps, or None if the list of maps was empty.
+ """
+ if len(maps) == 0:
+ return None
+
+ img_w, img_h = image_size
+ return np.stack(
+ [
+ cv2.resize(map_, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
+ for map_ in maps
+ ],
+ axis=-1,
+ )
+
+
+def _get_maps_folder(
+ cfg: dict,
+ train_frac: float,
+ shuffle: int,
+ model_prefix: str | None,
+ dest_folder: str | Path | None,
+) -> Path:
+ """Gets the destination folder for output maps"""
+ if dest_folder is None:
+ project_path = Path(cfg["project_path"])
+ eval_folder = auxiliaryfunctions.get_evaluation_folder(
+ trainFraction=train_frac,
+ shuffle=shuffle,
+ cfg=cfg,
+ engine=Engine.PYTORCH,
+ modelprefix=model_prefix,
+ )
+ dest_folder = project_path / eval_folder / "maps"
+
+ return Path(dest_folder)
diff --git a/deeplabcut/pose_estimation_pytorch/benchmark/__init__.py b/deeplabcut/pose_estimation_pytorch/benchmark/__init__.py
new file mode 100644
index 0000000000..117d127147
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/benchmark/__init__.py
@@ -0,0 +1,10 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
diff --git a/deeplabcut/pose_estimation_pytorch/benchmark/profile_HRNetCoAM.py b/deeplabcut/pose_estimation_pytorch/benchmark/profile_HRNetCoAM.py
new file mode 100644
index 0000000000..a2f8ad408b
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/benchmark/profile_HRNetCoAM.py
@@ -0,0 +1,20 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+# Script for reproducing results in Zhou* & Stoffl* et al. for BUCTD with CoAM
+
+# path=datapath
+# results=resultspath or put numbers
+
+# train model
+
+# evaluate and
+# check if predicted is close to result
diff --git a/deeplabcut/pose_estimation_pytorch/config/__init__.py b/deeplabcut/pose_estimation_pytorch/config/__init__.py
new file mode 100644
index 0000000000..21e72d4d4b
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/__init__.py
@@ -0,0 +1,27 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
+ make_basic_project_config,
+ make_pytorch_pose_config,
+ make_pytorch_test_config,
+)
+from deeplabcut.pose_estimation_pytorch.config.utils import (
+ available_detectors,
+ available_models,
+ update_config,
+ update_config_by_dotpath,
+)
+# For backwards compatibility
+from deeplabcut.core.config import (
+ read_config_as_dict,
+ write_config,
+ pretty_print,
+)
\ No newline at end of file
diff --git a/deeplabcut/pose_estimation_pytorch/config/animaltokenpose/animaltokenpose_base.yaml b/deeplabcut/pose_estimation_pytorch/config/animaltokenpose/animaltokenpose_base.yaml
new file mode 100644
index 0000000000..4c45a347d5
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/animaltokenpose/animaltokenpose_base.yaml
@@ -0,0 +1,49 @@
+# TODO: This default configuration file needs to be reviewed so it matches the original
+ # base TokenPose configuration, as defined in
+ # https://github.com/leeyegy/TokenPose/blob/main/experiments/coco/tokenpose/tokenpose_b_256_192_patch43_dim192_depth12_heads8.yaml
+method: td # Need to add a detector
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w32
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ interpolate_branches: false
+ increased_channel_count: false # changes backbone_output_channels to 128 when true
+ backbone_output_channels: 32
+ neck:
+ type: Transformer
+ feature_size:
+ - 64
+ - 64
+ patch_size:
+ - 4
+ - 4
+ num_keypoints: "num_bodyparts"
+ channels: 32
+ dim: 192
+ heads: 8
+ depth: 6
+ heads:
+ bodypart:
+ type: TransformerHead
+ target_generator:
+ type: HeatmapPlateauGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ generate_locref: false
+ criterion:
+ type: WeightedBCECriterion
+ predictor:
+ type: HeatmapPredictor
+ location_refinement: false
+ dim: 192
+ hidden_heatmap_dim: 384
+ heatmap_dim: 4096
+ apply_multi: true
+ heatmap_size:
+ - 64
+ - 64
+ apply_init: true
+ head_stride: 1
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_m.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_m.yaml
new file mode 100644
index 0000000000..2f444ba231
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_m.yaml
@@ -0,0 +1,19 @@
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_m
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 0.67
+ widen_factor: 0.75
+ backbone_output_channels: 768
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_s.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_s.yaml
new file mode 100644
index 0000000000..2cd01dbcfa
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_s.yaml
@@ -0,0 +1,19 @@
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_s
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 0.33
+ widen_factor: 0.5
+ backbone_output_channels: 512
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_x.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_x.yaml
new file mode 100644
index 0000000000..f598540d8a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/cspnext_x.yaml
@@ -0,0 +1,19 @@
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_x
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 1.33
+ widen_factor: 1.25
+ backbone_output_channels: 1280
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w18.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w18.yaml
new file mode 100644
index 0000000000..2bbb35ad76
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w18.yaml
@@ -0,0 +1,18 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+ train:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w18
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ interpolate_branches: false
+ increased_channel_count: false # changes backbone_output_channels to 128 when true
+ backbone_output_channels: 18
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w32.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w32.yaml
new file mode 100644
index 0000000000..a2e1a21cf6
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w32.yaml
@@ -0,0 +1,18 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+ train:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w32
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ interpolate_branches: false
+ increased_channel_count: false # changes backbone_output_channels to 128 when true
+ backbone_output_channels: 32
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w48.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w48.yaml
new file mode 100644
index 0000000000..9941090c53
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/hrnet_w48.yaml
@@ -0,0 +1,18 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+ train:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w48
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ interpolate_branches: false
+ increased_channel_count: false # changes backbone_output_channels to 128 when true
+ backbone_output_channels: 48
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_101.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_101.yaml
new file mode 100644
index 0000000000..287344f5ca
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_101.yaml
@@ -0,0 +1,18 @@
+model:
+ backbone:
+ type: ResNet
+ model_name: resnet101
+ output_stride: 16
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ backbone_output_channels: 2048
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
diff --git a/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_50.yaml b/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_50.yaml
new file mode 100644
index 0000000000..18b1467473
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/backbones/resnet_50.yaml
@@ -0,0 +1,18 @@
+model:
+ backbone:
+ type: ResNet
+ model_name: resnet50_gn
+ output_stride: 16
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ backbone_output_channels: 2048
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
\ No newline at end of file
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/aug_default.yaml b/deeplabcut/pose_estimation_pytorch/config/base/aug_default.yaml
new file mode 100644
index 0000000000..8ded47ed05
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/aug_default.yaml
@@ -0,0 +1,19 @@
+colormode: RGB
+inference:
+ normalize_images: true
+train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [0.5, 1.25]
+ translation: 0
+ covering: false
+ crop_sampling:
+ width: 448
+ height: 448
+ max_shift: 0.1
+ method: hybrid
+ gaussian_noise: 12.75
+ hist_eq: false
+ motion_blur: false
+ normalize_images: true
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/aug_top_down.yaml b/deeplabcut/pose_estimation_pytorch/config/base/aug_top_down.yaml
new file mode 100644
index 0000000000..61a03e3c3a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/aug_top_down.yaml
@@ -0,0 +1,21 @@
+colormode: RGB
+inference:
+ normalize_images: true
+ top_down_crop:
+ width: 256
+ height: 256
+train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [1.0, 1.0]
+ translation: 0
+ collate: null
+ covering: false
+ gaussian_noise: 12.75
+ hist_eq: false
+ motion_blur: false
+ normalize_images: true
+ top_down_crop:
+ width: 256
+ height: 256
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/base.yaml b/deeplabcut/pose_estimation_pytorch/config/base/base.yaml
new file mode 100644
index 0000000000..93d751bcf9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/base.yaml
@@ -0,0 +1,28 @@
+device: auto
+method: bu
+runner:
+ type: PoseTrainingRunner
+ gpus: null
+ key_metric: "test.mAP"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.0001
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-5 ], [ 1e-6 ] ]
+ milestones: [ 160, 190 ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 8
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 200
+ seed: 42
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/base_detector.yaml b/deeplabcut/pose_estimation_pytorch/config/base/base_detector.yaml
new file mode 100644
index 0000000000..ef3fbf656c
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/base_detector.yaml
@@ -0,0 +1,45 @@
+data:
+ colormode: RGB
+ inference:
+ normalize_images: true
+ train:
+ affine:
+ p: 0.5
+ rotation: 30
+ scaling: [ 1.0, 1.0 ]
+ translation: 40
+ collate:
+ type: ResizeFromDataSizeCollate
+ min_scale: 0.4
+ max_scale: 1.0
+ min_short_side: 128
+ max_short_side: 1152
+ multiple_of: 32
+ to_square: false
+ hflip: true
+ normalize_images: true
+device: auto
+runner:
+ type: DetectorTrainingRunner
+ key_metric: "test.mAP@50:95"
+ key_metric_asc: true
+ eval_interval: 10
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-4
+ scheduler:
+ type: LRListScheduler
+ params:
+ milestones: [ 160 ]
+ lr_list: [ [ 1e-5 ] ]
+ snapshots:
+ max_snapshots: 5
+ save_epochs: 25
+ save_optimizer_state: false
+train_settings:
+ batch_size: 1
+ dataloader_workers: 0
+ dataloader_pin_memory: false
+ display_iters: 500
+ epochs: 250
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts.yaml b/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts.yaml
new file mode 100644
index 0000000000..04501208b2
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts.yaml
@@ -0,0 +1,39 @@
+type: HeatmapHead
+weight_init: normal
+predictor:
+ type: HeatmapPredictor
+ apply_sigmoid: false
+ clip_scores: true
+ location_refinement: true
+ locref_std: 7.2801
+target_generator:
+ type: HeatmapGaussianGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ gradient_masking: false
+ generate_locref: true
+ locref_std: 7.2801
+criterion:
+ heatmap:
+ type: WeightedMSECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+heatmap_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_bodyparts"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+locref_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_bodyparts x 2"
+ kernel_size:
+ - 3
+ strides:
+ - 2
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts_with_paf.yaml b/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts_with_paf.yaml
new file mode 100644
index 0000000000..183a8ad260
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/head_bodyparts_with_paf.yaml
@@ -0,0 +1,62 @@
+type: DLCRNetHead
+predictor:
+ type: PartAffinityFieldPredictor
+ num_animals: "num_individuals"
+ num_multibodyparts: "num_bodyparts"
+ num_uniquebodyparts: 0
+ nms_radius: 5
+ sigma: 1.0
+ locref_stdev: 7.2801
+ min_affinity: 0.05
+ graph: "paf_graph"
+ edges_to_keep: "paf_edges_to_keep"
+ apply_sigmoid: true
+ clip_scores: false
+target_generator:
+ type: SequentialGenerator
+ generators:
+ - type: HeatmapPlateauGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ gradient_masking: false
+ generate_locref: true
+ locref_std: 7.2801
+ - type: PartAffinityFieldGenerator
+ graph: "paf_graph"
+ width: 20
+criterion:
+ heatmap:
+ type: WeightedBCECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+ paf:
+ type: WeightedHuberCriterion
+ weight: 0.1
+heatmap_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_bodyparts"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+locref_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_bodyparts x 2"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+paf_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_limbs x 2" # num_limbs = len(graph)
+ kernel_size:
+ - 3
+ strides:
+ - 2
+num_stages: 5
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/head_identity.yaml b/deeplabcut/pose_estimation_pytorch/config/base/head_identity.yaml
new file mode 100644
index 0000000000..eb9c253929
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/head_identity.yaml
@@ -0,0 +1,23 @@
+type: HeatmapHead
+predictor:
+ type: IdentityPredictor
+ apply_sigmoid: true
+target_generator:
+ type: HeatmapPlateauGenerator
+ num_heatmaps: "num_individuals"
+ pos_dist_thresh: 17
+ heatmap_mode: INDIVIDUAL
+ gradient_masking: false
+ generate_locref: false
+criterion:
+ heatmap:
+ type: WeightedBCECriterion
+ weight: 1.0
+heatmap_config:
+ channels:
+ - "backbone_output_channels"
+ - "num_individuals"
+ kernel_size:
+ - 3
+ strides:
+ - 2
diff --git a/deeplabcut/pose_estimation_pytorch/config/base/head_topdown.yaml b/deeplabcut/pose_estimation_pytorch/config/base/head_topdown.yaml
new file mode 100644
index 0000000000..57d5fa483d
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/base/head_topdown.yaml
@@ -0,0 +1,40 @@
+type: HeatmapHead
+weight_init: normal
+predictor:
+ type: HeatmapPredictor
+ apply_sigmoid: false
+ clip_scores: true
+ location_refinement: true
+ locref_std: 7.2801
+target_generator:
+ type: HeatmapGaussianGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ gradient_masking: true
+ background_weight: 0.0
+ generate_locref: true
+ locref_std: 7.2801
+criterion:
+ heatmap:
+ type: WeightedMSECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+heatmap_config:
+ channels:
+ - "backbone_output_channels"
+ kernel_size: []
+ strides: []
+ final_conv:
+ out_channels: "num_bodyparts"
+ kernel_size: 1
+locref_config:
+ channels:
+ - "backbone_output_channels"
+ kernel_size: []
+ strides: []
+ final_conv:
+ out_channels: "num_bodyparts x 2"
+ kernel_size: 1
diff --git a/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w18.yaml b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w18.yaml
new file mode 100644
index 0000000000..8e57ecc53e
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w18.yaml
@@ -0,0 +1,66 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w18
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ interpolate_branches: true
+ increased_channel_count: false
+ backbone_output_channels: 270
+ heads:
+ bodypart:
+ type: DEKRHead
+ weight_init: dekr
+ target_generator:
+ type: DEKRGenerator
+ num_joints: "num_bodyparts"
+ pos_dist_thresh: 17
+ bg_weight: 0.1
+ criterion:
+ heatmap:
+ type: DEKRHeatmapLoss
+ weight: 1
+ offset:
+ type: DEKROffsetLoss
+ weight: 0.03
+ predictor:
+ type: DEKRPredictor
+ apply_sigmoid: false
+ use_heatmap: false
+ clip_scores: true
+ num_animals: "num_individuals"
+ keypoint_score_type: combined
+ max_absorb_distance: 75
+ heatmap_config:
+ channels:
+ - 270
+ - 18
+ - "num_bodyparts + 1" # num_bodyparts + center keypoint
+ num_blocks: 1
+ dilation_rate: 1
+ final_conv_kernel: 1
+ offset_config:
+ channels:
+ - 270
+ - "num_bodyparts x 15" # num_bodyparts * num_offset_per_kpt
+ - "num_bodyparts"
+ num_offset_per_kpt: 15
+ num_blocks: 2
+ dilation_rate: 1
+ final_conv_kernel: 1
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.0005
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
+with_center_keypoints: true
\ No newline at end of file
diff --git a/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w32.yaml b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w32.yaml
new file mode 100644
index 0000000000..d682349bad
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w32.yaml
@@ -0,0 +1,66 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w32
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ interpolate_branches: true
+ increased_channel_count: false
+ backbone_output_channels: 480
+ heads:
+ bodypart:
+ type: DEKRHead
+ weight_init: dekr
+ target_generator:
+ type: DEKRGenerator
+ num_joints: "num_bodyparts"
+ pos_dist_thresh: 17
+ bg_weight: 0.1
+ criterion:
+ heatmap:
+ type: DEKRHeatmapLoss
+ weight: 1
+ offset:
+ type: DEKROffsetLoss
+ weight: 0.03
+ predictor:
+ type: DEKRPredictor
+ apply_sigmoid: false
+ use_heatmap: false
+ clip_scores: true
+ num_animals: "num_individuals"
+ keypoint_score_type: combined
+ max_absorb_distance: 75
+ heatmap_config:
+ channels:
+ - 480
+ - 32
+ - "num_bodyparts + 1" # num_bodyparts + center keypoint
+ num_blocks: 1
+ dilation_rate: 1
+ final_conv_kernel: 1
+ offset_config:
+ channels:
+ - 480
+ - "num_bodyparts x 15" # num_bodyparts * num_offset_per_kpt
+ - "num_bodyparts"
+ num_offset_per_kpt: 15
+ num_blocks: 2
+ dilation_rate: 1
+ final_conv_kernel: 1
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.0005
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
+with_center_keypoints: true
\ No newline at end of file
diff --git a/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w48.yaml b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w48.yaml
new file mode 100644
index 0000000000..4b34f62125
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/dekr/dekr_w48.yaml
@@ -0,0 +1,66 @@
+data:
+ inference:
+ auto_padding: # Required for HRNet backbones
+ pad_width_divisor: 32
+ pad_height_divisor: 32
+model:
+ backbone:
+ type: HRNet
+ model_name: hrnet_w48
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ interpolate_branches: true
+ increased_channel_count: false
+ backbone_output_channels: 720
+ heads:
+ bodypart:
+ type: DEKRHead
+ weight_init: dekr
+ target_generator:
+ type: DEKRGenerator
+ num_joints: "num_bodyparts"
+ pos_dist_thresh: 17
+ bg_weight: 0.1
+ criterion:
+ heatmap:
+ type: DEKRHeatmapLoss
+ weight: 1
+ offset:
+ type: DEKROffsetLoss
+ weight: 0.03
+ predictor:
+ type: DEKRPredictor
+ apply_sigmoid: false
+ use_heatmap: false
+ clip_scores: true
+ num_animals: "num_individuals"
+ keypoint_score_type: combined
+ max_absorb_distance: 75
+ heatmap_config:
+ channels:
+ - 720
+ - 48
+ - "num_bodyparts + 1" # num_bodyparts + center keypoint
+ num_blocks: 1
+ dilation_rate: 1
+ final_conv_kernel: 1
+ offset_config:
+ channels:
+ - 720
+ - "num_bodyparts x 15" # num_bodyparts * num_offset_per_kpt
+ - "num_bodyparts"
+ num_offset_per_kpt: 15
+ num_blocks: 2
+ dilation_rate: 1
+ final_conv_kernel: 1
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 0.0005
+ scheduler:
+ type: LRListScheduler
+ params:
+ lr_list: [ [ 1e-4 ], [ 1e-5 ] ]
+ milestones: [ 90, 120 ]
+with_center_keypoints: true
\ No newline at end of file
diff --git a/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_mobilenet_v3_large_fpn.yaml b/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_mobilenet_v3_large_fpn.yaml
new file mode 100644
index 0000000000..3d07eb41c0
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_mobilenet_v3_large_fpn.yaml
@@ -0,0 +1,5 @@
+model:
+ type: FasterRCNN
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ variant: fasterrcnn_mobilenet_v3_large_fpn
diff --git a/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_resnet50_fpn_v2.yaml b/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_resnet50_fpn_v2.yaml
new file mode 100644
index 0000000000..53711c6810
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/detectors/fasterrcnn_resnet50_fpn_v2.yaml
@@ -0,0 +1,5 @@
+model:
+ type: FasterRCNN
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+ variant: fasterrcnn_resnet50_fpn_v2
diff --git a/deeplabcut/pose_estimation_pytorch/config/detectors/ssdlite.yaml b/deeplabcut/pose_estimation_pytorch/config/detectors/ssdlite.yaml
new file mode 100644
index 0000000000..c0357b34a0
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/detectors/ssdlite.yaml
@@ -0,0 +1,6 @@
+model:
+ type: SSDLite
+ freeze_bn_stats: true
+ freeze_bn_weights: false
+train_settings:
+ batch_size: 16
diff --git a/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride16_ms5.yaml b/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride16_ms5.yaml
new file mode 100644
index 0000000000..6b9036c6ce
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride16_ms5.yaml
@@ -0,0 +1,72 @@
+model:
+ backbone:
+ type: DLCRNet
+ model_name: resnet50
+ pretrained: true
+ output_stride: 16
+ backbone_output_channels: 2304
+ pose_model:
+ stride: 8
+ heads:
+ bodypart:
+ type: DLCRNetHead
+ predictor:
+ type: PartAffinityFieldPredictor
+ num_animals: "num_individuals"
+ num_multibodyparts: "num_bodyparts"
+ num_uniquebodyparts: 0
+ nms_radius: 5
+ sigma: 1.0
+ locref_stdev: 7.2801
+ min_affinity: 0.05
+ graph: "paf_graph"
+ edges_to_keep: "paf_edges_to_keep"
+ target_generator:
+ type: SequentialGenerator
+ generators:
+ - type: HeatmapPlateauGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ generate_locref: true
+ locref_std: 7.2801
+ - type: PartAffinityFieldGenerator
+ graph: "paf_graph"
+ width: 20
+ criterion:
+ heatmap:
+ type: WeightedBCECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+ paf:
+ type: WeightedHuberCriterion
+ weight: 0.1
+ heatmap_config:
+ channels:
+ - 2304
+ - "num_bodyparts"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+ locref_config:
+ channels:
+ - 2304
+ - "num_bodyparts x 2"
+ kernel_size:
+ - 3
+ strides:
+ - 2
+ paf_config:
+ channels:
+ - 2304
+ - "num_limbs x 2" # num_limbs = len(graph)
+ kernel_size:
+ - 3
+ strides:
+ - 2
+ num_stages: 5
+runner:
+ eval_interval: 25 # slow evaluation with poor Part-Affinity fields
diff --git a/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride32_ms5.yaml b/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride32_ms5.yaml
new file mode 100644
index 0000000000..26ec928ee7
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/dlcrnet/dlcrnet_stride32_ms5.yaml
@@ -0,0 +1,81 @@
+model:
+ backbone:
+ type: DLCRNet
+ model_name: resnet50
+ pretrained: true
+ output_stride: 32
+ backbone_output_channels: 2304
+ pose_model:
+ stride: 8
+ heads:
+ bodypart:
+ type: DLCRNetHead
+ predictor:
+ type: PartAffinityFieldPredictor
+ num_animals: "num_individuals"
+ num_multibodyparts: "num_bodyparts"
+ num_uniquebodyparts: 0
+ nms_radius: 5
+ sigma: 1.0
+ locref_stdev: 7.2801
+ min_affinity: 0.05
+ graph: "paf_graph"
+ edges_to_keep: "paf_edges_to_keep"
+ target_generator:
+ type: SequentialGenerator
+ generators:
+ - type: HeatmapPlateauGenerator
+ num_heatmaps: "num_bodyparts"
+ pos_dist_thresh: 17
+ heatmap_mode: KEYPOINT
+ generate_locref: true
+ locref_std: 7.2801
+ - type: PartAffinityFieldGenerator
+ graph: "paf_graph"
+ width: 20
+ criterion:
+ heatmap:
+ type: WeightedBCECriterion
+ weight: 1.0
+ locref:
+ type: WeightedHuberCriterion
+ weight: 0.05
+ paf:
+ type: WeightedHuberCriterion
+ weight: 0.1
+ heatmap_config:
+ channels:
+ - 2304
+ - 1152
+ - "num_bodyparts"
+ kernel_size:
+ - 3
+ - 3
+ strides:
+ - 2
+ - 2
+ locref_config:
+ channels:
+ - 2304
+ - 1152
+ - "num_bodyparts x 2"
+ kernel_size:
+ - 3
+ - 3
+ strides:
+ - 2
+ - 2
+ paf_config:
+ channels:
+ - 2304
+ - 1152
+ - "num_limbs x 2" # num_limbs = len(graph)
+ kernel_size:
+ - 3
+ - 3
+ strides:
+ - 2
+ - 2
+ num_stages: 5
+runner:
+ eval_interval: 25 # slow evaluation with poor Part-Affinity fields
diff --git a/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py b/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
new file mode 100644
index 0000000000..396a431c05
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
@@ -0,0 +1,527 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Methods to create the configuration files for PyTorch DeepLabCut models"""
+from __future__ import annotations
+
+import copy
+from pathlib import Path
+
+from deeplabcut.core.config import read_config_as_dict, write_config
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.config.utils import (
+ get_config_folder_path,
+ load_backbones,
+ load_base_config,
+ replace_default_values,
+ update_config,
+)
+from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal
+
+
+def make_pytorch_pose_config(
+ project_config: dict,
+ pose_config_path: str | Path,
+ net_type: str | None = None,
+ top_down: bool = False,
+ detector_type: str | None = None,
+ weight_init: WeightInitialization | None = None,
+ save: bool = False,
+) -> dict:
+ """Creates a PyTorch pose configuration file for a DeepLabCut project
+
+ The base/ folder contains default configurations, such as data augmentations or
+ heatmap heads (that can be used to predict pose or identity based on visual
+ features). These files are used to create pose model configurations.
+
+ All available backbone configurations are stored in the backbones/ folder.
+ - any backbone can be a single animal model with a heatmap head added on top
+ - any backbone can be a top-down model with a detector and a heatmap head
+ - any backbone can be a bottom-up model with a detector and a heatmap + PAF head
+
+ All other model architectures have their own folders, with different variants
+ available. Top-down model architectures must specify `method: TD` in their
+ configuration files, from which this method adds a backbone configuration.
+
+ Placeholder values (such as `num_bodyparts` or `num_individuals`) are filled in
+ based on the project config file.
+
+ Args:
+ project_config: the DeepLabCut project config
+ pose_config_path: the path where the pytorch pose configuration will be saved
+ net_type: the architecture of the desired pose estimation model
+ top_down: when the net_type is a backbone, whether to create a top-down model
+ by associating a detector to the pose model. Required for multi-animal
+ projects when net_type is a backbone (as a backbone + heatmap head can only
+ predict pose for single individuals).
+ detector_type: for top-down pose models, the architecture of the desired object
+ detection model
+ weight_init: Specify how model weights should be initialized. If None, ImageNet
+ pretrained weights from Timm will be loaded when training.
+ save: Whether to save the model configuration file to the ``pose_config_path``.
+
+ Returns:
+ the PyTorch pose configuration file
+ """
+ multianimal_project = project_config.get("multianimalproject", False)
+ individuals = project_config.get("individuals", ["single"])
+ with_identity = project_config.get("identity")
+ bodyparts = auxiliaryfunctions.get_bodyparts(project_config)
+ unique_bpts = auxiliaryfunctions.get_unique_bodyparts(project_config)
+
+ if net_type is None:
+ net_type = project_config.get("default_net_type", "resnet_50")
+
+ configs_dir = get_config_folder_path()
+ pose_config = load_base_config(configs_dir)
+ pose_config = add_metadata(project_config, pose_config, pose_config_path)
+ pose_config["net_type"] = net_type
+
+ backbones = load_backbones(configs_dir)
+ if net_type in backbones:
+ if not top_down and multianimal_project:
+ model_cfg = create_backbone_with_paf_model(
+ configs_dir=configs_dir,
+ net_type=net_type,
+ num_individuals=len(individuals),
+ bodyparts=bodyparts,
+ paf_parameters=_get_paf_parameters(project_config, bodyparts),
+ )
+ else:
+ model_cfg = create_backbone_with_heatmap_model(
+ configs_dir=configs_dir,
+ net_type=net_type,
+ multianimal_project=multianimal_project,
+ bodyparts=bodyparts,
+ top_down=top_down,
+ )
+ else:
+ architecture = net_type.split("_")[0]
+ default_value_kwargs = {}
+ if architecture == "dlcrnet":
+ default_value_kwargs.update(_get_paf_parameters(project_config, bodyparts))
+
+ cfg_path = configs_dir / architecture / f"{net_type}.yaml"
+ model_cfg = read_config_as_dict(cfg_path)
+ model_cfg = replace_default_values(
+ model_cfg,
+ num_bodyparts=len(bodyparts),
+ num_individuals=len(individuals),
+ **default_value_kwargs,
+ )
+
+ is_top_down = model_cfg.get("method", "BU").upper() == "TD"
+ if is_top_down:
+ model_cfg = add_detector(
+ configs_dir,
+ model_cfg,
+ len(individuals),
+ detector_type=detector_type,
+ )
+
+ # add the default augmentations to the config
+ aug_filename = "aug_top_down.yaml" if is_top_down else "aug_default.yaml"
+ aug_cfg = {"data": read_config_as_dict(configs_dir / "base" / aug_filename)}
+ pose_config = update_config(pose_config, aug_cfg)
+
+ # add the model to the config
+ pose_config = update_config(pose_config, model_cfg)
+
+ # set the dataset from which to load weights
+ if weight_init is not None:
+ pose_config["train_settings"]["weight_init"] = weight_init.to_dict()
+
+ # add a unique bodypart head if needed
+ if len(unique_bpts) > 0:
+ if is_top_down:
+ raise ValueError(
+ f"You selected a top-down model architecture ({net_type}), but you have"
+ f" unique bodyparts, which is not yet implemented for top-down models."
+ " Please select a bottom-up architecture such as `resnet_50` for single"
+ " animal projects or `dlcrnet_50` for multi-animal projects."
+ )
+
+ pose_config = add_unique_bodypart_head(
+ configs_dir,
+ pose_config,
+ num_unique_bodyparts=len(unique_bpts),
+ backbone_output_channels=pose_config["model"]["backbone_output_channels"],
+ )
+
+ # add an identity head if needed
+ if with_identity:
+ if is_top_down:
+ raise ValueError(
+ f"You selected a top-down model architecture ({net_type}), but you have"
+ f" set `identity: true`, which is not yet implemented for top-down"
+ f" models. Please select a bottom-up architecture such as `dlcrnet_50`"
+ f" to train with identity, or set `identity: false`."
+ )
+
+ pose_config = add_identity_head(
+ configs_dir,
+ pose_config,
+ num_individuals=len(individuals),
+ backbone_output_channels=pose_config["model"]["backbone_output_channels"],
+ )
+
+ # sort first-level keys to make it prettier
+ pose_config = dict(sorted(pose_config.items()))
+
+ if save:
+ write_config(pose_config_path, pose_config, overwrite=True)
+
+ return pose_config
+
+
+def make_pytorch_test_config(
+ model_config: dict,
+ test_config_path: str | Path,
+ save: bool = False,
+) -> dict:
+ """Creates the test configuration for a model
+
+ Args:
+ model_config: The PyTorch config for the model.
+ test_config_path: The path of the test config
+ save: Whether to save the test config to ``test_config_path``.
+
+ Returns:
+ The test configuration file.
+ """
+ bodyparts = model_config["metadata"]["bodyparts"]
+ unique_bodyparts = model_config["metadata"]["unique_bodyparts"]
+ all_joint_names = bodyparts + unique_bodyparts
+
+ test_config = dict(
+ dataset=model_config["metadata"]["project_path"],
+ dataset_type="multi-animal-imgaug", # required for downstream tracking
+ num_joints=len(all_joint_names),
+ all_joints=[[i] for i in range(len(all_joint_names))],
+ all_joints_names=all_joint_names,
+ net_type=model_config["net_type"],
+ global_scale=1,
+ scoremap_dir="test",
+ )
+ if save:
+ write_config(test_config_path, test_config)
+
+ return test_config
+
+
+def make_basic_project_config(
+ dataset_path: Path | str,
+ bodyparts: list[str],
+ max_individuals: int,
+ multi_animal: bool = True,
+) -> dict:
+ """Creates a basic configuration dict that can be used to create model configs.
+
+ This should be used to create the `project_config` given to
+ `make_pytorch_pose_config` for non-DeepLabCut projects (e.g. when creating a
+ configuration file for a model that will be trained on a COCO dataset).
+
+ Args:
+ dataset_path: The path to the dataset for which the config will be created.
+ bodyparts: The bodyparts labeled for individuals in the dataset.
+ max_individuals: The maximum number of individuals to detect in a single image.
+ multi_animal: Whether multiple animals can be present in an image.
+
+ Returns:
+ The created project configuration dict that can be given to
+ `make_pytorch_pose_config`.
+
+ Examples:
+ Creating a `pytorch_config` for a ResNet50 backbone with a part-affinity head (
+ as multi_animal=True and top_down=False)
+
+ >>> import deeplabcut.pose_estimation_pytorch as pep
+ >>> project_config = pep.config.make_basic_project_config(
+ >>> dataset_path="/path/coco",
+ >>> bodyparts=["nose", "left_eye", "right_eye"],
+ >>> max_individuals=12,
+ >>> multi_animal=True,
+ >>> )
+ >>> model_config = pep.config.make_pytorch_pose_config(
+ >>> project_config=project_config,
+ >>> pose_config_path="/path/coco/models/resnet50/pytorch_config.yaml",
+ >>> net_type="resnet_50",
+ >>> top_down=False,
+ >>> save=True,
+ >>> )
+
+ Creating a `pytorch_config` for a ResNet50 backbone with a simple heatmap head
+ (as the project is single-animal):
+
+ >>> import deeplabcut.pose_estimation_pytorch as pep
+ >>> project_config = pep.config.make_basic_project_config(
+ >>> dataset_path="/path/coco",
+ >>> bodyparts=["nose", "left_eye", "right_eye"],
+ >>> max_individuals=1,
+ >>> multi_animal=False,
+ >>> )
+ >>> model_config = pep.config.make_pytorch_pose_config(
+ >>> project_config=project_config,
+ >>> pose_config_path="/path/coco/models/resnet50/pytorch_config.yaml",
+ >>> net_type="resnet_50",
+ >>> top_down=False,
+ >>> save=True,
+ >>> )
+ """
+ return dict(
+ project_path=str(dataset_path),
+ multianimalproject=multi_animal,
+ bodyparts=bodyparts,
+ multianimalbodyparts=bodyparts,
+ uniquebodyparts=[],
+ individuals=[f"individual{i:03d}" for i in range(max_individuals)],
+ )
+
+
+def add_metadata(
+ project_config: dict, config: dict, pose_config_path: str | Path
+) -> dict:
+ """Adds metadata to a pytorch pose configuration
+
+ Args:
+ project_config: the project configuration
+ config: the pytorch pose configuration
+ pose_config_path: the path where the pytorch pose configuration will be saved
+
+ Returns:
+ the configuration with a `meta` key added
+ """
+ config = copy.deepcopy(config)
+ config["metadata"] = {
+ "project_path": project_config["project_path"],
+ "pose_config_path": str(pose_config_path),
+ "bodyparts": auxiliaryfunctions.get_bodyparts(project_config),
+ "unique_bodyparts": auxiliaryfunctions.get_unique_bodyparts(project_config),
+ "individuals": project_config.get("individuals", ["animal"]),
+ "with_identity": project_config.get("identity", False),
+ }
+ return config
+
+
+def create_backbone_with_heatmap_model(
+ configs_dir: Path,
+ net_type: str,
+ multianimal_project: bool,
+ bodyparts: list[str],
+ top_down: bool,
+) -> dict:
+ """
+ Creates a simple heatmap pose estimation model, composed of a backbone and a head
+ predicting heatmaps and location refinement maps
+
+ Args:
+ configs_dir: path to the DeepLabCut "configs" directory
+ net_type: the type of backbone to create the model with (e.g., resnet_50)
+ multianimal_project: whether this model is created for a multi-animal project
+ bodyparts: the bodyparts to detect
+ top_down: whether the model will be associated to a detector to form a top-down
+ pose estimation model
+
+ Returns:
+ the backbone + heatmap model configuration
+
+ Raises:
+ ValueError: if the model is being created for a multi-animal project but the
+ head won't be associated with a detector (heatmaps can only predict
+ bodyparts for a single individual).
+ """
+ if multianimal_project and not top_down:
+ raise ValueError(
+ "A pose model formed of a backbone and simple heatmap + location refinement"
+ " head can only be used for single animal projects. As you're working with"
+ " a multi-animal project, please select a multi-individual model instead of"
+ f" {net_type} or use a detector to create a top-down model (create your"
+ f" configuration with `make_pytorch_pose_config(..., top_down=True)`)."
+ )
+
+ # add the backbone to the config
+ model_config = read_config_as_dict(configs_dir / "backbones" / f"{net_type}.yaml")
+ backbone_output_channels = model_config["model"]["backbone_output_channels"]
+
+ model_config["method"] = "bu"
+ bodypart_head_name = "head_bodyparts.yaml"
+ if top_down:
+ model_config["method"] = "td"
+ bodypart_head_name = "head_topdown.yaml"
+
+ # add a bodypart head
+ bodypart_head_config = read_config_as_dict(
+ configs_dir / "base" / bodypart_head_name
+ )
+ model_config["model"]["heads"] = {
+ "bodypart": replace_default_values(
+ bodypart_head_config,
+ num_bodyparts=len(bodyparts),
+ backbone_output_channels=backbone_output_channels,
+ )
+ }
+ return model_config
+
+
+def create_backbone_with_paf_model(
+ configs_dir: Path,
+ net_type: str,
+ num_individuals: int,
+ bodyparts: list[str],
+ paf_parameters: dict,
+) -> dict:
+ """
+ Creates a pose estimation model, composed of a backbone and a head predicting
+ heatmaps, location refinement maps and part affinity fields for multi-animal pose
+ estimation.
+
+ Args:
+ configs_dir: path to the DeepLabCut "configs" directory
+ net_type: the type of backbone to create the model with (e.g., resnet_50)
+ num_individuals: the maximum number of individuals in a frame
+ bodyparts: the bodyparts to detect
+ paf_parameters: the parameters for the PAF
+
+ Returns:
+ the backbone + heatmap, location refinement, PAF model configuration
+ """
+ # add the backbone to the config
+ model_config = read_config_as_dict(configs_dir / "backbones" / f"{net_type}.yaml")
+ backbone_output_channels = model_config["model"]["backbone_output_channels"]
+
+ # add a bodypart head
+ bodypart_head_config = read_config_as_dict(
+ configs_dir / "base" / f"head_bodyparts_with_paf.yaml"
+ )
+ model_config["model"]["heads"] = {
+ "bodypart": replace_default_values(
+ bodypart_head_config,
+ num_bodyparts=len(bodyparts),
+ num_individuals=num_individuals,
+ backbone_output_channels=backbone_output_channels,
+ **paf_parameters,
+ )
+ }
+ return model_config
+
+
+def add_detector(
+ configs_dir: Path,
+ config: dict,
+ num_individuals: int,
+ detector_type: str | None = None,
+) -> dict:
+ """Adds a detector to a model
+
+ Args:
+ configs_dir: path to the DeepLabCut "configs" directory
+ config: model configuration to update
+ num_individuals: the maximum number of individuals the model should detect
+ detector_type: the type of detector to use (if None, uses ``ssdlite``)
+
+ Returns:
+ the model configuration with an added detector config
+ """
+ if detector_type is None:
+ detector_type = "ssdlite" # default detector
+
+ detector_type = detector_type.lower()
+ config = copy.deepcopy(config)
+ detector_config = update_config(
+ read_config_as_dict(configs_dir / "base" / "base_detector.yaml"),
+ read_config_as_dict(configs_dir / "detectors" / f"{detector_type}.yaml"),
+ )
+ detector_config = replace_default_values(
+ detector_config, num_individuals=num_individuals,
+ )
+ config["detector"] = dict(sorted(detector_config.items()))
+ return config
+
+
+def add_unique_bodypart_head(
+ configs_dir: Path,
+ config: dict,
+ num_unique_bodyparts: int,
+ backbone_output_channels: int,
+) -> dict:
+ """Adds a unique bodypart head to a model
+
+ Args:
+ configs_dir: path to the DeepLabCut "configs" directory
+ config: model configuration to update
+ num_unique_bodyparts: the number of unique bodyparts to detect
+ backbone_output_channels: the number of channels output by the model backbone
+
+ Returns:
+ the configuration with an added unique bodypart head
+ """
+ config = copy.deepcopy(config)
+ unique_head_config = replace_default_values(
+ read_config_as_dict(configs_dir / "base" / "head_bodyparts.yaml"),
+ num_bodyparts=num_unique_bodyparts,
+ backbone_output_channels=backbone_output_channels,
+ )
+ unique_head_config["target_generator"]["label_keypoint_key"] = "keypoints_unique"
+ config["model"]["heads"]["unique_bodypart"] = unique_head_config
+ return config
+
+
+def add_identity_head(
+ configs_dir: Path,
+ config: dict,
+ num_individuals: int,
+ backbone_output_channels: int,
+) -> dict:
+ """Adds an identity head to a model
+
+ Args:
+ configs_dir: path to the DeepLabCut "configs" directory
+ config: model configuration to update
+ num_individuals: the number of individuals to re-identify
+ backbone_output_channels: the number of channels output by the model backbone
+
+ Returns:
+ the configuration with an added identity head
+ """
+ config = copy.deepcopy(config)
+ id_head_config = read_config_as_dict(configs_dir / "base" / "head_identity.yaml")
+ config["model"]["heads"]["identity"] = replace_default_values(
+ id_head_config,
+ num_individuals=num_individuals,
+ backbone_output_channels=backbone_output_channels,
+ )
+ return config
+
+
+def _get_paf_parameters(
+ project_config: dict,
+ bodyparts: list[str],
+ num_limbs_threshold: int = 105,
+ paf_graph_degree: int = 6,
+) -> dict:
+ """Gets values for PAF parameters from the project configuration"""
+ paf_graph = [
+ [i, j] for i in range(len(bodyparts)) for j in range(i + 1, len(bodyparts))
+ ]
+ num_limbs = len(paf_graph)
+ # If the graph is unnecessarily large (with 15+ keypoints by default),
+ # we randomly prune it to a size guaranteeing an average node degree of 6;
+ # see Suppl. Fig S9c in Lauer et al., 2022.
+ if num_limbs >= num_limbs_threshold:
+ paf_graph = auxfun_multianimal.prune_paf_graph(
+ paf_graph,
+ average_degree=paf_graph_degree,
+ )
+ num_limbs = len(paf_graph)
+ return {
+ "paf_graph": paf_graph,
+ "num_limbs": num_limbs,
+ "paf_edges_to_keep": project_config.get("paf_best", list(range(num_limbs))),
+ }
diff --git a/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_m.yaml b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_m.yaml
new file mode 100644
index 0000000000..d6bc515f94
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_m.yaml
@@ -0,0 +1,98 @@
+data:
+ inference:
+ top_down_crop:
+ width: 256
+ height: 256
+ train:
+ random_bbox_transform:
+ shift_factor: 0.16
+ shift_prob: 0.3
+ scale_factor: [0.75, 1.25]
+ scale_prob: 1.0
+ p: 1.0
+ top_down_crop:
+ width: 256
+ height: 256
+method: td # Need to add a detector
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_m
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 0.67
+ widen_factor: 0.75
+ backbone_output_channels: 768
+ heads:
+ bodypart:
+ type: RTMCCHead
+ weight_init: RTMPose
+ target_generator:
+ type: SimCCGenerator
+ input_size: [256, 256]
+ smoothing_type: gaussian
+ sigma: [5.66, 5.66]
+ simcc_split_ratio: 2.0
+ label_smooth_weight: 0.0
+ normalize: false
+ criterion:
+ x:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ y:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ predictor:
+ type: SimCCPredictor
+ simcc_split_ratio: 2.0
+ input_size: [256, 256]
+ in_channels: 768
+ out_channels: "num_bodyparts"
+ in_featuremap_size: [8, 8] # input_size / backbone stride
+ simcc_split_ratio: 2.0
+ final_layer_kernel_size: 7
+ gau_cfg:
+ hidden_dims: 256
+ s: 128
+ expansion_factor: 2
+ dropout_rate: 0
+ drop_path: 0.0
+ act_fn: "SiLU"
+ use_rel_bias: false
+ pos_enc: false
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-3
+ scheduler:
+ type: SequentialLR
+ params:
+ schedulers:
+ - type: LinearLR
+ params:
+ start_factor: 0.001
+ end_factor: 1.0
+ total_iters: 5
+ - type: CosineAnnealingLR
+ params:
+ T_max: 200 # max_epochs // 2
+ eta_min: 5e-5 # ~base_lr / 20
+ - type: LRListScheduler
+ params:
+ milestones:
+ - 0
+ lr_list:
+ - - 5e-5
+ milestones:
+ - 200 # max_epochs // 2
+ - 400
+train_settings:
+ batch_size: 32
+ dataloader_workers: 4
+ dataloader_pin_memory: false
+ epochs: 400
diff --git a/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_s.yaml b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_s.yaml
new file mode 100644
index 0000000000..fbc4ff7ed4
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_s.yaml
@@ -0,0 +1,98 @@
+data:
+ inference:
+ top_down_crop:
+ width: 256
+ height: 256
+ train:
+ random_bbox_transform:
+ shift_factor: 0.16
+ shift_prob: 0.3
+ scale_factor: [0.75, 1.25]
+ scale_prob: 1.0
+ p: 1.0
+ top_down_crop:
+ width: 256
+ height: 256
+method: td # Need to add a detector
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_s
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 0.33
+ widen_factor: 0.5
+ backbone_output_channels: 512
+ heads:
+ bodypart:
+ type: RTMCCHead
+ weight_init: RTMPose
+ target_generator:
+ type: SimCCGenerator
+ input_size: [256, 256]
+ smoothing_type: gaussian
+ sigma: [5.66, 5.66]
+ simcc_split_ratio: 2.0
+ label_smooth_weight: 0.0
+ normalize: false
+ criterion:
+ x:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ y:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ predictor:
+ type: SimCCPredictor
+ simcc_split_ratio: 2.0
+ input_size: [256, 256]
+ in_channels: 512
+ out_channels: "num_bodyparts"
+ in_featuremap_size: [8, 8] # input_size / backbone stride
+ simcc_split_ratio: 2.0
+ final_layer_kernel_size: 7
+ gau_cfg:
+ hidden_dims: 256
+ s: 128
+ expansion_factor: 2
+ dropout_rate: 0
+ drop_path: 0.0
+ act_fn: "SiLU"
+ use_rel_bias: false
+ pos_enc: false
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-3
+ scheduler:
+ type: SequentialLR
+ params:
+ schedulers:
+ - type: LinearLR
+ params:
+ start_factor: 0.001
+ end_factor: 1.0
+ total_iters: 5
+ - type: CosineAnnealingLR
+ params:
+ T_max: 200 # max_epochs // 2
+ eta_min: 5e-5 # ~base_lr / 20
+ - type: LRListScheduler
+ params:
+ milestones:
+ - 0
+ lr_list:
+ - - 5e-5
+ milestones:
+ - 200 # max_epochs // 2
+ - 400
+train_settings:
+ batch_size: 32
+ dataloader_workers: 4
+ dataloader_pin_memory: false
+ epochs: 400
diff --git a/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_x.yaml b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_x.yaml
new file mode 100644
index 0000000000..0a49baec75
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/rtmpose/rtmpose_x.yaml
@@ -0,0 +1,98 @@
+data:
+ inference:
+ top_down_crop:
+ width: 384
+ height: 384
+ train:
+ random_bbox_transform:
+ shift_factor: 0.16
+ shift_prob: 0.3
+ scale_factor: [ 0.75, 1.25 ]
+ scale_prob: 1.0
+ p: 1.0
+ top_down_crop:
+ width: 384
+ height: 384
+method: td # Need to add a detector
+model:
+ backbone:
+ type: CSPNeXt
+ model_name: cspnext_x
+ freeze_bn_stats: false
+ freeze_bn_weights: false
+ deepen_factor: 1.33
+ widen_factor: 1.25
+ backbone_output_channels: 1280
+ heads:
+ bodypart:
+ type: RTMCCHead
+ weight_init: RTMPose
+ target_generator:
+ type: SimCCGenerator
+ input_size: [384, 384]
+ smoothing_type: gaussian
+ sigma: [6.93, 6.93]
+ simcc_split_ratio: 2.0
+ label_smooth_weight: 0.0
+ normalize: false
+ criterion:
+ x:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ y:
+ type: KLDiscreteLoss
+ use_target_weight: true
+ beta: 10.0
+ label_softmax: true
+ predictor:
+ type: SimCCPredictor
+ simcc_split_ratio: 2.0
+ input_size: [384, 384]
+ in_channels: 1280
+ out_channels: "num_bodyparts"
+ in_featuremap_size: [12, 12] # input_size / backbone stride
+ simcc_split_ratio: 2.0
+ final_layer_kernel_size: 7
+ gau_cfg:
+ hidden_dims: 256
+ s: 128
+ expansion_factor: 2
+ dropout_rate: 0
+ drop_path: 0.0
+ act_fn: "SiLU"
+ use_rel_bias: false
+ pos_enc: false
+runner:
+ optimizer:
+ type: AdamW
+ params:
+ lr: 1e-3
+ scheduler:
+ type: SequentialLR
+ params:
+ schedulers:
+ - type: LinearLR
+ params:
+ start_factor: 0.001
+ end_factor: 1.0
+ total_iters: 5
+ - type: CosineAnnealingLR
+ params:
+ T_max: 200 # max_epochs // 2
+ eta_min: 5e-5 # ~base_lr / 20
+ - type: LRListScheduler
+ params:
+ milestones:
+ - 0
+ lr_list:
+ - - 5e-5
+ milestones:
+ - 200 # max_epochs // 2
+ - 400
+train_settings:
+ batch_size: 32
+ dataloader_workers: 4
+ dataloader_pin_memory: false
+ epochs: 400
diff --git a/deeplabcut/pose_estimation_pytorch/config/utils.py b/deeplabcut/pose_estimation_pytorch/config/utils.py
new file mode 100644
index 0000000000..80904edad9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/config/utils.py
@@ -0,0 +1,262 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Util functions to create pytorch pose configuration files"""
+from __future__ import annotations
+
+import copy
+from pathlib import Path
+
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.utils import auxiliaryfunctions
+
+
+def replace_default_values(
+ config: dict | list,
+ num_bodyparts: int | None = None,
+ num_individuals: int | None = None,
+ backbone_output_channels: int | None = None,
+ **kwargs,
+) -> dict:
+ """Replaces placeholder values in a model configuration with their actual values.
+
+ This method allows to create template PyTorch configurations for models with values
+ such as "num_bodyparts", which are replaced with the number of bodyparts for a
+ project when making its Pytorch configuration.
+
+ This code can also do some basic arithmetic. You can write "num_bodyparts x 2" (or
+ any factor other than 2) for location refinement channels, and the number of
+ channels will be twice the number of bodyparts. You can write
+ "backbone_output_channels // 2" for the number of channels in a layer, and it will
+ be half the number of channels output by the backbone. You can write
+ "num_bodyparts + 1" (such as for DEKR heatmaps, where a "center" bodypart is added).
+
+ The three base placeholder values that can be computed are "num_bodyparts",
+ "num_individuals" and "backbone_output_channels". You can add more through the
+ keyword arguments (such as "paf_graph": list[tuple[int, int]] or
+ "paf_edges_to_keep": list[int] for DLCRNet models).
+
+ Args:
+ config: the configuration in which to replace default values
+ num_bodyparts: the number of bodyparts
+ num_individuals: the number of individuals
+ backbone_output_channels: the number of backbone output channels
+ kwargs: other placeholder values to fill in
+
+ Returns:
+ the configuration with placeholder values replaced
+
+ Raises:
+ ValueError: if there is a placeholder value who's "updated" value was not
+ given to the method
+ """
+
+ def get_updated_value(variable: str) -> int | list[int]:
+ var_parts = variable.strip().split(" ")
+ var_name = var_parts[0]
+ if updated_values[var_name] is None:
+ raise ValueError(
+ f"Found {variable} in the configuration file, but there is no default "
+ f"value for this variable."
+ )
+
+ if len(var_parts) == 1:
+ return updated_values[var_name]
+ elif len(var_parts) == 3:
+ operator, factor = var_parts[1], var_parts[2]
+ if not factor.isdigit():
+ raise ValueError(f"F must be an integer in variable: {variable}")
+
+ factor = int(factor)
+ if operator == "+":
+ return updated_values[var_name] + factor
+ elif operator == "x":
+ return updated_values[var_name] * factor
+ elif operator == "//":
+ return updated_values[var_name] // factor
+ else:
+ raise ValueError(f"Unknown operator for variable: {variable}")
+
+ raise ValueError(
+ f"Found {variable} in the configuration file, but cannot parse it."
+ )
+
+ updated_values = {
+ "num_bodyparts": num_bodyparts,
+ "num_individuals": num_individuals,
+ "backbone_output_channels": backbone_output_channels,
+ **kwargs,
+ }
+
+ config = copy.deepcopy(config)
+ if isinstance(config, dict):
+ keys_to_update = list(config.keys())
+ elif isinstance(config, list):
+ keys_to_update = range(len(config))
+ else:
+ raise ValueError(f"Config to update must be dict or list, found {type(config)}")
+
+ for k in keys_to_update:
+ if isinstance(config[k], (list, dict)):
+ config[k] = replace_default_values(
+ config[k],
+ num_bodyparts,
+ num_individuals,
+ backbone_output_channels,
+ **kwargs,
+ )
+ elif (
+ isinstance(config[k], str)
+ and config[k].strip().split(" ")[0] in updated_values.keys()
+ ):
+ config[k] = get_updated_value(config[k])
+
+ return config
+
+
+def update_config(config: dict, updates: dict, copy_original: bool = True) -> dict:
+ """Updates items in the configuration file
+
+ The configuration dict should only be composed of primitive Python types
+ (dict, list and values). This is the case when reading the file using
+ `read_config_as_dict`.
+
+ Args:
+ config: the configuration dict to update
+ updates: the updates to make to the configuration dict
+ copy_original: whether to copy the original dict before updating it
+
+ Returns:
+ the updated dictionary
+ """
+ if copy_original:
+ config = copy.deepcopy(config)
+
+ for k, v in updates.items():
+ if k in config and isinstance(config[k], dict) and isinstance(v, dict):
+ if k in ("optimizer", "scheduler") and config["type"] != v["type"]:
+ # if changing the optimizer or scheduler type, update all values
+ config[k] = v
+ else:
+ config[k] = update_config(config[k], v, copy_original=False)
+ else:
+ config[k] = copy.deepcopy(v)
+ return config
+
+
+def update_config_by_dotpath(
+ config: dict, updates: dict, copy_original: bool = True
+) -> dict:
+ """Updates items in the configuration file using dot notation for nested keys
+
+ The configuration dict should only be composed of primitive Python types
+ (dict, list and values). This is the case when reading the file using
+ `read_config_as_dict`.
+
+ Args:
+ config: the configuration dict to update
+ updates: single-level dict with dot notation keys indicating nested paths
+ e.g. {"device": "cuda", "runner.gpus": [0,1]}
+ copy_original: whether to copy the original dict before updating it
+
+ Returns:
+ the updated dictionary
+ """
+ if copy_original:
+ config = copy.deepcopy(config)
+
+ for key, value in updates.items():
+ # Split key into parts by dots
+ parts = key.split(".")
+
+ # Handle non-nested case
+ if len(parts) == 1:
+ config[key] = copy.deepcopy(value)
+ continue
+
+ # Navigate to nested location
+ current = config
+ for part in parts[:-1]:
+ if part not in current:
+ current[part] = {}
+ current = current[part]
+
+ # Set the value at final location
+ current[parts[-1]] = copy.deepcopy(value)
+
+ return config
+
+
+def get_config_folder_path() -> Path:
+ """Returns: the Path to the folder containing the "configs" for DeepLabCut 3.0"""
+ dlc_parent_path = Path(auxiliaryfunctions.get_deeplabcut_path())
+ return dlc_parent_path / "pose_estimation_pytorch" / "config"
+
+
+def load_base_config(config_folder_path: Path) -> dict:
+ """Returns: the base configuration for all PyTorch DeepLabCut models"""
+ base_dir = config_folder_path / "base"
+ base_config = read_config_as_dict(base_dir / "base.yaml")
+ return base_config
+
+
+def load_backbones(configs_dir: Path) -> list[str]:
+ """
+ Args:
+ configs_dir: the Path to the folder containing the "configs" for PyTorch
+ DeepLabCut
+
+ Returns:
+ all backbones with default configurations that can be used
+ """
+ backbone_dir = configs_dir / "backbones"
+ backbones = [p.stem for p in backbone_dir.iterdir() if p.suffix == ".yaml"]
+ return backbones
+
+
+def load_detectors(configs_dir: Path) -> list[str]:
+ """
+ Args:
+ configs_dir: the Path to the folder containing the "configs" for PyTorch
+ DeepLabCut
+
+ Returns:
+ all detectors that are available
+ """
+ detector_dir = configs_dir / "detectors"
+ detectors = [p.stem for p in detector_dir.iterdir() if p.suffix == ".yaml"]
+ return detectors
+
+
+def available_models() -> list[str]:
+ """Returns: the possible variants of models that can be used"""
+ configs_folder_path = get_config_folder_path()
+ backbones = load_backbones(configs_folder_path)
+ models = set()
+ for backbone in backbones:
+ models.add(backbone)
+ models.add("top_down_" + backbone)
+
+ other_architectures = [
+ p
+ for p in configs_folder_path.iterdir()
+ if p.is_dir() and not p.name in ("backbones", "base", "detectors")
+ ]
+ for folder in other_architectures:
+ variants = [p.stem for p in folder.iterdir() if p.suffix == ".yaml"]
+ for variant in variants:
+ models.add(variant)
+
+ return list(sorted(models))
+
+
+def available_detectors() -> list[str]:
+ """Returns: all the possible detectors that can be used"""
+ return load_detectors(get_config_folder_path())
diff --git a/deeplabcut/pose_estimation_pytorch/data/__init__.py b/deeplabcut/pose_estimation_pytorch/data/__init__.py
new file mode 100644
index 0000000000..e4ae41d5b1
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/__init__.py
@@ -0,0 +1,31 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.data.base import Loader
+from deeplabcut.pose_estimation_pytorch.data.cocoloader import COCOLoader
+from deeplabcut.pose_estimation_pytorch.data.collate import COLLATE_FUNCTIONS
+from deeplabcut.pose_estimation_pytorch.data.dlcloader import DLCLoader
+from deeplabcut.pose_estimation_pytorch.data.dataset import (
+ PoseDatasetParameters,
+ PoseDataset,
+)
+from deeplabcut.pose_estimation_pytorch.data.image import top_down_crop
+from deeplabcut.pose_estimation_pytorch.data.postprocessor import (
+ build_bottom_up_postprocessor,
+ build_detector_postprocessor,
+ build_top_down_postprocessor,
+ Postprocessor,
+)
+from deeplabcut.pose_estimation_pytorch.data.preprocessor import (
+ build_bottom_up_preprocessor,
+ build_top_down_preprocessor,
+ Preprocessor,
+)
+from deeplabcut.pose_estimation_pytorch.data.transforms import build_transforms
diff --git a/deeplabcut/pose_estimation_pytorch/data/base.py b/deeplabcut/pose_estimation_pytorch/data/base.py
new file mode 100644
index 0000000000..cb54b1d97c
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/base.py
@@ -0,0 +1,329 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import albumentations as A
+import numpy as np
+
+import deeplabcut.core.config as config_utils
+import deeplabcut.pose_estimation_pytorch.config as config
+from deeplabcut.pose_estimation_pytorch.data.dataset import (
+ PoseDataset,
+ PoseDatasetParameters,
+)
+from deeplabcut.pose_estimation_pytorch.data.utils import (
+ _compute_crop_bounds,
+ bbox_from_keypoints,
+ map_id_to_annotations,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+class Loader(ABC):
+ """
+ Abstract class that represents a blueprint for loading and processing dataset information.
+
+ Methods:
+ load_data(mode: str = 'train') -> dict:
+ Abstract method to convert the project configuration to a standard COCO format.
+ create_dataset(images: dict = None, annotations: dict = None, transform: object = None,
+ mode: str = "train", task: Task = Task.BOTTOM_UP) -> PoseDataset:
+ Creates and returns a PoseDataset given a set of images, annotations, and other parameters.
+ _compute_bboxes(images, annotations, method: str = 'gt') -> dict:
+ Retrieves all bounding boxes based on the specified method.
+ get_dataset_parameters(*args, **kwargs) -> dict:
+ Returns a dictionary containing dataset parameters derived from the configuration.
+ """
+
+ def __init__(self, model_config_path: str | Path) -> None:
+ self.model_config_path = Path(model_config_path)
+ self.model_cfg = config_utils.read_config_as_dict(str(model_config_path))
+ self.pose_task = Task(self.model_cfg["method"])
+ self._loaded_data: dict[str, dict[str, list[dict]]] = {}
+
+ @property
+ def model_folder(self) -> Path:
+ """Returns: The path of the folder containing the model data"""
+ return self.model_config_path.parent
+
+ def update_model_cfg(self, updates: dict) -> None:
+ """Updates the model configuration
+
+ Args:
+ updates: the items to update in the model configuration
+ """
+ self.model_cfg = config.update_config_by_dotpath(self.model_cfg, updates)
+ config_utils.write_config(self.model_config_path, self.model_cfg)
+
+ @abstractmethod
+ def load_data(self, mode: str = "train") -> dict[str, list[dict]]:
+ """Abstract method to convert the project configuration to a standard coco format.
+
+ Raises:
+ NotImplementedError: This method must be implemented in the derived classes.
+ """
+ raise NotImplementedError
+
+ def image_filenames(self, mode: str = "train") -> list[str]:
+ """
+ Args:
+ mode: {"train", "test"} whether to load train or test data
+
+ Returns:
+ the image paths for this mode
+ """
+ if mode not in self._loaded_data:
+ self._loaded_data[mode] = self.load_data(mode)
+
+ data = self._loaded_data[mode]
+ return [image["file_name"] for image in data["images"]]
+
+ def ground_truth_keypoints(
+ self, mode: str = "train", unique_bodypart: bool = False
+ ) -> dict[str, np.ndarray]:
+ """
+ Creates a dictionary containing the ground truth data
+
+ TODO: make more efficient
+
+ Args:
+ mode: {"train", "test"} whether to load train or test data
+ unique_bodypart: returns the ground truth for unique bodyparts
+
+ Raises:
+ ValueError if unique_bodypart=True but there are no unique bodyparts
+
+ Returns:
+ A dict mapping image paths to the ground truth annotations for the mode in
+ the format:
+ {'image': keypoints with shape (num_individuals, num_keypoints, 2)}
+ """
+ parameters = self.get_dataset_parameters()
+ if unique_bodypart:
+ if not parameters.num_unique_bpts > 0:
+ raise ValueError("There are no unique bodyparts in this dataset!")
+ individuals = ["single"]
+ num_bodyparts = parameters.num_unique_bpts
+ else:
+ individuals = parameters.individuals
+ num_bodyparts = parameters.num_joints
+
+ if "weight_init" in self.model_cfg["train_settings"]:
+ weight_init_cfg = self.model_cfg["train_settings"]["weight_init"]
+ if weight_init_cfg["memory_replay"]:
+ conversion_array = weight_init_cfg["conversion_array"]
+ num_bodyparts = len(conversion_array)
+
+ if mode not in self._loaded_data:
+ self._loaded_data[mode] = self.load_data(mode)
+ data = self._loaded_data[mode]
+
+ annotations = self.filter_annotations(data["annotations"], task=Task.BOTTOM_UP)
+ img_to_ann_map = map_id_to_annotations(annotations)
+
+ ground_truth_dict = {}
+ for image in data["images"]:
+ image_path = image["file_name"]
+ individual_keypoints = {
+ annotations[i]["individual"]: annotations[i]["keypoints"]
+ for i in img_to_ann_map[image["id"]]
+ }
+ gt_array = np.zeros((len(individuals), num_bodyparts, 3))
+ # Keep the shape of the ground truth
+ for idv_idx, idv in enumerate(individuals):
+ if idv in individual_keypoints:
+ keypoints = individual_keypoints[idv].reshape(num_bodyparts, -1)
+ gt_array[idv_idx, :, :] = keypoints[:, :3]
+
+ ground_truth_dict[image_path] = gt_array
+
+ return ground_truth_dict
+
+ def ground_truth_bboxes(self, mode: str = "train") -> dict[str, dict]:
+ """Creates a dictionary containing the ground truth bounding boxes
+
+ Args:
+ mode: {"train", "test"} whether to load train or test data
+
+ Returns:
+ A dict mapping image paths to the ground truth annotations for the mode in
+ the format:
+ {
+ 'path/to/image000.png': {
+ "width": (int) the width of the image, in pixels
+ "height": (int) the height of the image, in pixels
+ "bboxes": (np.ndarray) bboxes with shape (num_individuals, xywh)
+ },
+ 'path/to/image000.png': {...},
+ }
+ """
+ if mode not in self._loaded_data:
+ self._loaded_data[mode] = self.load_data(mode)
+ data = self._loaded_data[mode]
+
+ annotations = self.filter_annotations(data["annotations"], task=Task.DETECT)
+ img_to_ann_map = map_id_to_annotations(annotations)
+
+ ground_truth_dict = {}
+ for image in data["images"]:
+ image_path = image["file_name"]
+ img_shape = image["height"], image["width"], 3
+ bboxes = [annotations[i]["bbox"] for i in img_to_ann_map[image["id"]]]
+ if len(bboxes) == 0:
+ bboxes = np.zeros((0, 4))
+ else:
+ bboxes = _compute_crop_bounds(np.stack(bboxes, axis=0), img_shape)
+
+ ground_truth_dict[image_path] = dict(
+ width=image["width"],
+ height=image["height"],
+ bboxes=bboxes,
+ )
+
+ return ground_truth_dict
+
+ def create_dataset(
+ self,
+ transform: A.BaseCompose | None = None,
+ mode: str = "train",
+ task: Task = Task.BOTTOM_UP,
+ ) -> PoseDataset:
+ """
+ Creates a PoseDataset based on provided arguments.
+
+ Args:
+ transform: Transformation to be applied on dataset. Defaults to None.
+ mode: Mode in which dataset is to be used (e.g., 'train', 'test'). Defaults to 'train'.
+ task: Task for which the dataset is being used. Defaults to 'BU'.
+
+ Returns:
+ PoseDataset: An instance of the PoseDataset class.
+
+ Raises:
+ Any exception raised by `get_dataset_parameters` or `load_data` methods.
+ """
+ parameters = self.get_dataset_parameters()
+ data = self.load_data(mode)
+ data["annotations"] = self.filter_annotations(data["annotations"], task)
+ dataset = PoseDataset(
+ images=data["images"],
+ annotations=data["annotations"],
+ transform=transform,
+ mode=mode,
+ task=task,
+ parameters=parameters,
+ )
+ return dataset
+
+ @abstractmethod
+ def get_dataset_parameters(self) -> PoseDatasetParameters:
+ """
+ Retrieves dataset parameters based on the instance's configuration.
+
+ Returns:
+ An instance of the PoseDatasetParameters with the parameters set.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def filter_annotations(annotations: list[dict], task: Task) -> list[dict]:
+ """Filters annotations based on the task, removing empty annotations
+
+ For pose estimation tasks, annotations with empty keypoints are removed. For
+ detection task, annotations with no bounding boxes are removed
+
+ Args:
+ annotations: the annotations to filter
+ task: the task for which to filter
+
+ Returns:
+ list: the filtered annotations
+ """
+ filtered_annotations = []
+ for annotation in annotations:
+ keypoints = annotation["keypoints"].reshape(-1, 3)
+ if task in (Task.DETECT, Task.TOP_DOWN) and (
+ annotation["bbox"][2] <= 0 or annotation["bbox"][3] <= 0
+ ):
+ continue
+ elif task != Task.DETECT and np.all(keypoints[:, :2] <= 0):
+ continue
+
+ filtered_annotations.append(annotation)
+
+ return filtered_annotations
+
+ @staticmethod
+ def _compute_bboxes(
+ images: list[dict],
+ annotations: list[dict],
+ method: str = "gt",
+ ):
+ """TODO: Nastya method of bbox computation (detection bbox, seg. mask, ...)
+ Retrieves all bounding boxes based on the given method.
+
+ Args:
+ images: A list of images.
+ annotations: A list of annotations corresponding to images.
+ method (str, optional): Method to use for retrieving bounding boxes. Defaults to 'gt'.
+ - 'gt': Ground truth bounding boxes.
+ - 'detection bbox': Bounding boxes from detection.
+ - 'keypoints': Bounding boxes from keypoints.
+ - 'segmentation mask': Bounding boxes from segmentation masks.
+
+ Returns:
+ list: Updated annotations based on the given method.
+
+ Raises:
+ ValueError: If 'bbox' is not found in annotation when method is 'gt'.
+ ValueError: If method is not one of 'gt', 'detection bbox', 'keypoints', or 'segmentation mask'.
+ """
+
+ if not method:
+ return annotations
+
+ elif method == "gt":
+ for i, annotation in enumerate(annotations):
+ if "bbox" not in annotation:
+ # or do something else?
+ raise ValueError(
+ f"Bounding box not found in annotation {annotation}, please "
+ "chose another bbox computation method"
+ )
+ return annotations
+
+ elif method == "detection bbox":
+ raise NotImplementedError
+
+ elif method == "keypoints":
+ bbox_margin = 20 # TODO: should not be hardcoded
+ min_area = 1 # TODO: should not be hardcoded
+ img_id_to_annotations = map_id_to_annotations(annotations)
+ for img in images:
+ anns = [annotations[idx] for idx in img_id_to_annotations[img["id"]]]
+ for a in anns:
+ a["bbox"] = bbox_from_keypoints(
+ keypoints=a["keypoints"],
+ image_h=img["height"],
+ image_w=img["width"],
+ margin=bbox_margin,
+ )
+ a["area"] = max(min_area, (a["bbox"][2] * a["bbox"][3]).item())
+ return annotations
+
+ elif method == "segmentation mask":
+ raise NotImplementedError
+
+ else:
+ raise ValueError(f"Unknown method: {method}")
diff --git a/deeplabcut/pose_estimation_pytorch/data/cocoloader.py b/deeplabcut/pose_estimation_pytorch/data/cocoloader.py
new file mode 100644
index 0000000000..0fa4f872bf
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/cocoloader.py
@@ -0,0 +1,355 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import json
+import os
+import warnings
+from pathlib import Path
+
+import numpy as np
+
+from deeplabcut.pose_estimation_pytorch.data.base import Loader
+from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters
+from deeplabcut.pose_estimation_pytorch.data.utils import (
+ map_id_to_annotations,
+ map_image_path_to_id,
+)
+
+
+class COCOLoader(Loader):
+ """
+ Attributes:
+ project_root: root directory path of the COCO project.
+ train_json_filename: the name of the json file containing the train annotations
+ test_json_filename: the name of the json file containing the train annotations.
+ None if there is no test set.
+
+ Examples:
+ loader = COCOLoader(
+ project_root='/path/to/project/',
+ model_config_path='/path/to/project/experiments/train/pytorch_config.yaml'
+ train_json_filename="train.json",
+ test_json_filename="test.json",
+ )
+ """
+
+ def __init__(
+ self,
+ project_root: str | Path,
+ model_config_path: str | Path,
+ train_json_filename: str = "train.json",
+ test_json_filename: str = "test.json",
+ ):
+ super().__init__(Path(model_config_path))
+ self.project_root = Path(project_root)
+ self.train_json_filename = train_json_filename
+ self.test_json_filename = test_json_filename
+ self._dataset_parameters = None
+
+ self.train_json = self.load_json(self.project_root, self.train_json_filename)
+ self.test_json = None
+ if self.test_json_filename:
+ self.test_json = self.load_json(self.project_root, self.test_json_filename)
+
+ def get_dataset_parameters(self) -> PoseDatasetParameters:
+ """
+ Retrieves dataset parameters based on the instance's configuration.
+
+ Returns:
+ An instance of the PoseDatasetParameters with the parameters set.
+ """
+ if self._dataset_parameters is None:
+ num_individuals, bodyparts = self.get_project_parameters(self.train_json)
+
+ crop_cfg = self.model_cfg["data"]["train"].get("top_down_crop", {})
+ crop_w, crop_h = crop_cfg.get("width", 256), crop_cfg.get("height", 256)
+ crop_margin = crop_cfg.get("margin", 0)
+
+ self._dataset_parameters = PoseDatasetParameters(
+ bodyparts=bodyparts,
+ unique_bpts=[],
+ individuals=[f"individual{i}" for i in range(num_individuals)],
+ with_center_keypoints=self.model_cfg.get("with_center_keypoints", False),
+ color_mode=self.model_cfg.get("color_mode", "RGB"),
+ top_down_crop_size=(crop_w, crop_h),
+ top_down_crop_margin=crop_margin,
+ )
+
+ return self._dataset_parameters
+
+ @staticmethod
+ def load_json(project_root: str | Path, filename: str) -> dict:
+ """Load a JSON file from the annotations directory.
+
+ Args:
+ project_root: path to the root directory for the project
+ filename: filename of JSON file to load
+
+ Returns:
+ json_obj: JSON object loaded from the file
+
+ Raises:
+ FileNotFoundError if the file does not exist
+ ValueError if the object stored in the file is not a dict
+
+ Examples:
+ Check https://docs.trainingdata.io/v1.0/Export%20Format/COCO/ to see
+ examples of how a json file looks like.
+ """
+ json_path = os.path.join(project_root, "annotations", filename)
+ if not os.path.exists(json_path):
+ raise FileNotFoundError(f"File {json_path} does not exist.")
+
+ with open(json_path, "r") as f:
+ json_obj = json.load(f)
+
+ if not isinstance(json_obj, dict):
+ raise ValueError("COCO datasets need to be saved in JSON Objects")
+
+ return json_obj
+
+ @staticmethod
+ def validate_categories(coco_json: dict) -> dict:
+ """Checks that the categories for the COCO project are valid.
+
+ Checks that there is no category with ID 0 in the dataset, as this causes issues
+ with torchvision object detectors (label 0 is reserved for background
+ detections). If that's the case, all category IDs are shifted by 1 such that
+ there is no longer a category 0.
+
+ Currently, detectors can only be trained with a single category. This also
+ ensures that all annotations have `category_id` set to 1.
+
+ Args:
+ coco_json: the COCO dictionary containing the annotations
+
+ Returns:
+ the validated COCO object
+ """
+ cat_0 = False
+ for cat in coco_json["categories"]:
+ if cat["id"] == 0:
+ cat_0 = cat
+ warnings.warn(
+ f"Found a category with ID 0 ({cat}) in the COCO dataset. This is not"
+ f" allowed, as category ID 0 is reserved as the background ID for"
+ f" torchvision detectors. All category IDs have been shifted by 1."
+ )
+
+ if len(coco_json["categories"]) > 1:
+ warnings.warn(
+ f"Found more than 1 category in the project. This is currently not"
+ f" supported in DeepLabCut. All annotations will be given category 1"
+ )
+
+ if cat_0:
+ for cat in coco_json["categories"]:
+ cat["id"] = 1
+
+ if cat_0 or len(coco_json["categories"]) > 1:
+ for ann in coco_json["annotations"]:
+ ann["category_id"] = 1
+
+ return coco_json
+
+ @staticmethod
+ def validate_images(project_root: str | Path, coco_json: dict) -> dict:
+ """Goes over images and annotations to look for potential errors
+
+ This code tries to ensure that training a model on this project does not crash
+ down the line
+
+ Completes relative image filepaths to '/project_root/images/file_name'. Absolute
+ filepaths are not updated (which allows storing images to be stored in a folder
+ other than the project root) Then checks that all images files exist in the file
+ system.
+
+ Args:
+ project_root: the root path of the COCO project
+ coco_json: the COCO dictionary containing the annotations
+
+ Returns:
+ the validated COCO object
+ """
+ image_ids = set()
+ missing_images = {}
+ validated_images = []
+ for image in coco_json["images"]:
+ image_filename = Path(image["file_name"])
+ if image_filename.is_absolute():
+ image_path = image_filename
+ else:
+ image_path = Path(project_root) / "images" / image["file_name"]
+ image["file_name"] = str(image_path)
+
+ if not image_path.exists():
+ missing_images[image["id"]] = image["file_name"]
+ else:
+ validated_images.append(image)
+ image_ids.add(image["id"])
+
+ if len(missing_images) > 0:
+ warnings.warn(
+ f"There are {len(missing_images)} images that cannot be found (here"
+ " are some):"
+ )
+ for img_id, file_name in missing_images.items():
+ print(f" * {img_id}: {file_name}")
+
+ coco_json["images"] = validated_images
+
+ if len(missing_images) > 0:
+ validated_annotations = []
+ for ann in coco_json["annotations"]:
+ if ann["image_id"] not in missing_images:
+ validated_annotations.append(ann)
+
+ coco_json["annotations"] = validated_annotations
+
+ validated_annotations = []
+ for ann in coco_json["annotations"]:
+ if ann["image_id"] in image_ids:
+ validated_annotations.append(ann)
+
+ if len(coco_json["annotations"]) < len(validated_annotations):
+ warnings.warn(
+ f"Found some annotations for which the image ID was not in the images."
+ f" Removing them from the dataset."
+ )
+ print(f" All annotations: {len(coco_json['annotations'])}")
+ print(f" Annotations with correct image IDs: {len(validated_annotations)}")
+ coco_json["annotations"] = validated_annotations
+
+ return coco_json
+
+ def load_data(self, mode: str = "train") -> dict:
+ """Convert data from JSON object to dictionary.
+ Args:
+ mode: indicates which JSON object to convert. Defaults to "train".
+
+ Returns:
+ the train or test data
+ """
+ # todo: add validation
+ if mode == "train":
+ data = self.train_json
+ elif mode == "test":
+ data = self.test_json
+ else:
+ raise AttributeError(f"Unknown mode: {mode}")
+
+ data = COCOLoader.validate_categories(data)
+ data = COCOLoader.validate_images(self.project_root, data)
+
+ annotations_per_image = {}
+ for annotation in data["annotations"]:
+ annotation["keypoints"] = np.array(annotation["keypoints"], dtype=float)
+ annotation["bbox"] = np.array(annotation["bbox"], dtype=float)
+
+ # set individual index
+ image_id = annotation["image_id"]
+ individual_idx = annotations_per_image.get(image_id, 0)
+ annotation["individual"] = f"individual{individual_idx}"
+ annotations_per_image[image_id] = individual_idx + 1
+
+ filter_annotations = []
+ for annotation in data['annotations']:
+ keypoints = annotation['keypoints']
+ bbox = annotation['bbox']
+ if np.all(keypoints <= 0) or len(bbox) == 0:
+ continue
+ filter_annotations.append(annotation)
+
+ data["annotations"] = filter_annotations
+
+ # FIXME: why estimating bbox when there are already bbox?
+ annotations_with_bbox = self._compute_bboxes(
+ data["images"],
+ data["annotations"],
+ method="gt",
+ )
+ data["annotations"] = annotations_with_bbox
+ return data
+
+ @staticmethod
+ def get_project_parameters(train_json: dict) -> tuple[int, list[str]]:
+ """
+ Loads the parameters for the project from the train json file
+ TODO: Should this compute the number also using the test json?
+
+ Args:
+ train_json: the json dictionary containing the data for training
+
+ Returns:
+ int: the maximum number of individuals in a single image
+ list[str]: the name of keypoints annotated in this project
+ """
+ # TODO: Check that there's a single category
+ bodyparts = train_json["categories"][0]["keypoints"]
+
+ img_to_annotations = map_id_to_annotations(train_json["annotations"])
+ if len(img_to_annotations) == 0:
+ raise ValueError(f"No images found in the dataset: {train_json}!")
+ elif len(img_to_annotations) == 1:
+ num_individuals = len(list(img_to_annotations.values())[0])
+ else:
+ num_individuals = max(
+ *[len(a_ids) for a_ids in img_to_annotations.values()]
+ )
+
+ return num_individuals, bodyparts
+
+ def predictions_to_coco(
+ self,
+ predictions: dict[str, dict[str, np.ndarray]],
+ mode: str = "train",
+ ) -> list[dict]:
+ """Converts detections to COCO format
+
+ Args:
+ predictions: a dictionary mapping image name to the predictions made for it
+ mode: {"train", "test"} the mode that the predictions were made with
+
+ Returns:
+ The COCO-format predictions
+ """
+ data = self.load_data(mode)
+ image_path_to_id = map_image_path_to_id(data["images"])
+
+ # TODO: no unique bodyparts for COCO
+ coco_predictions = []
+ for image_path, pred in predictions.items():
+ image_id = image_path_to_id[image_path]
+
+ # Shape (num_individuals, num_keypoints, 3)
+ individuals = pred["bodyparts"]
+ for idx, keypoints in enumerate(individuals):
+ if not np.all(keypoints == -1):
+ score = np.mean(keypoints[:, 2]).item()
+ keypoints = keypoints.copy()
+ keypoints[:, 2] = 2 # set visibility instead of score
+ coco_pred = {
+ "image_id": int(image_id),
+ "category_id": 1, # TODO: get category ID from prediction?
+ "keypoints": keypoints.reshape(-1).tolist(),
+ "score": float(score),
+ }
+ if "bboxes" in pred:
+ coco_pred["bbox"] = pred["bboxes"][idx].reshape(-1).tolist()
+ if "bbox_scores" in pred:
+ coco_pred["bbox_scores"] = (
+ pred["bbox_scores"][idx].reshape(-1).tolist()
+ )
+
+ coco_predictions.append(coco_pred)
+
+ return coco_predictions
diff --git a/deeplabcut/pose_estimation_pytorch/data/collate.py b/deeplabcut/pose_estimation_pytorch/data/collate.py
new file mode 100644
index 0000000000..701075ee53
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/collate.py
@@ -0,0 +1,191 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Custom collate functions"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+from torch.utils.data import default_collate
+
+from deeplabcut.pose_estimation_pytorch.data.image import resize_and_random_crop
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+
+COLLATE_FUNCTIONS = Registry("collate_functions", build_func=build_from_cfg)
+
+
+class CollateFunction(ABC):
+ """A class that can be called as a collate function"""
+
+ @abstractmethod
+ def __call__(self, batch) -> dict | list:
+ """Returns: the collated batch"""
+ raise NotImplementedError()
+
+
+class ResizeCollate(CollateFunction, ABC):
+ """A collate function which resizes all images in a batch to the same size
+
+ Args:
+ max_shift: The maximum shift, in pixels, to add to the random crop (this means
+ there can be a slight border around the image)
+ max_size: The maximum size of the long edge of the image when resized. If the
+ longest side will be greater than this value, resizes such that the longest
+ side is this size, and the shortest side is smaller than the desired size.
+ This is useful to keep some information from images with extreme aspect
+ ratios.
+ seed: The random seed to use to sample scales/sizes.
+ """
+
+ def __init__(
+ self,
+ max_shift: int = 10,
+ max_size: int = 2048,
+ seed: int = 0,
+ ) -> None:
+ self.generator = np.random.default_rng(seed=seed)
+ self.max_size = max_size
+ self.max_shift = max_shift
+ self._current_batch = []
+
+ @abstractmethod
+ def _sample_scale(self) -> int | tuple[int, int]:
+ """Returns: the target shape for images in the batch"""
+ raise NotImplementedError()
+
+ def __call__(self, batch) -> dict | list:
+ """Returns: the collated batch"""
+ self._current_batch = batch
+ new_size = self._sample_scale()
+ updated_batch = []
+ for item in batch:
+ image, new_targets = resize_and_random_crop(
+ image=item["image"],
+ targets=item,
+ size=new_size,
+ max_size=self.max_size,
+ max_shift=self.max_shift,
+ )
+ new_targets["image"] = image
+ updated_batch.append(new_targets)
+
+ return default_collate(updated_batch)
+
+
+@COLLATE_FUNCTIONS.register_module
+class ResizeFromDataSizeCollate(ResizeCollate):
+ """A collate function which resizes all images in a batch to the same size
+
+ The target size is obtained by taking the size of the first image in the batch, and
+ multiplying it by a scale taken uniformly at random from (min_scale, max_scale).
+
+ The aspect ratio of all images in the batch is preserved, with cropping/padding used
+ to generate images of the correct shapes.
+
+ If to_square:
+ The images will be resized to squares, where the side is the short side of the
+ original image.
+ else:
+ The images will be resized to a scaled version of the shape of the first image.
+
+ Args:
+ min_scale: The minimum scale factor to apply to the image size
+ max_scale: The maximum scale factor to apply to the image size
+ min_short_side: The smallest size for the target short side.
+ max_short_side: The largest size for the target short side.
+ max_ratio: The largest aspect ratio allowed for a target (longSide / shortSide).
+ If the aspect ratio is larger, it will be clamped to max_ratio. Must be >=1.
+ multiple_of: If defined, the height and width of all target sizes will be a
+ multiple of this value.
+ to_square: Whether images should be resized to squares.
+ """
+
+ def __init__(
+ self,
+ min_scale: float,
+ max_scale: float,
+ min_short_side: int = 128,
+ max_short_side: int = 1152,
+ max_ratio: float = 2.0,
+ multiple_of: int | None = None,
+ to_square: bool = False,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.min_short_side = min_short_side
+ self.max_short_side = max_short_side
+ self.max_ratio = max_ratio
+ self.multiple_of = multiple_of
+ self.to_square = to_square
+
+ def _sample_scale(self) -> int | tuple[int, int]:
+ if len(self._current_batch) == 0:
+ raise ValueError("Cannot sample frame shape: no items in current batch")
+
+ h, w = self._current_batch[0]["image"].shape[1:]
+ scale = self.generator.uniform(self.min_scale, self.max_scale)
+ if self.to_square:
+ short_side = min(h, w)
+ size = int(round(
+ min(self.max_short_side, max(self.min_short_side, scale * short_side))
+ ))
+ if self.multiple_of is not None:
+ size = _to_multiple(size, self.multiple_of)
+ return size
+
+ short, long = min(h, w), max(h, w)
+ ratio = long / short
+ if ratio > self.max_ratio:
+ ratio = self.max_ratio
+
+ short_size = int(
+ round(min(self.max_short_side, max(self.min_short_side, scale * short)))
+ )
+ if h < w:
+ h = short_size
+ w = int(ratio * short_size)
+ else:
+ h = int(ratio * short_size)
+ w = short_size
+
+ if self.multiple_of is not None:
+ w = _to_multiple(w, self.multiple_of)
+ h = _to_multiple(h, self.multiple_of)
+
+ return h, w
+
+
+@COLLATE_FUNCTIONS.register_module
+class ResizeFromListCollate(ResizeCollate):
+ """A collate function which resizes all images in a batch to the same size
+
+ The target size image size is sampled from a list. If it's a list of integers,
+ all images will be resized into squares. If it's a list of tuples, that will be the
+ target (h, w) for images.
+
+ Args:
+ scales: The target sizes to resize the images to.
+ """
+
+ def __init__(self, scales: list[int] | list[tuple[int, int]], **kwargs) -> None:
+ super().__init__(**kwargs)
+ self.scales = scales
+
+ def _sample_scale(self) -> int | tuple[int, int]:
+ return self.generator.choice(self.scales)
+
+
+def _to_multiple(value: int, of: int) -> int:
+ """Returns: the smallest integer >= ``value`` which is a multiple of ``of``"""
+ return of * ((value + of - 1) // of)
diff --git a/deeplabcut/pose_estimation_pytorch/data/dataset.py b/deeplabcut/pose_estimation_pytorch/data/dataset.py
new file mode 100644
index 0000000000..3a9333bc8a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/dataset.py
@@ -0,0 +1,450 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import albumentations as A
+import cv2
+import numpy as np
+from torch.utils.data import Dataset
+
+from deeplabcut.pose_estimation_pytorch.data.image import top_down_crop
+from deeplabcut.pose_estimation_pytorch.data.utils import (
+ _crop_image_keypoints,
+ _extract_keypoints_and_bboxes,
+ apply_transform,
+ map_id_to_annotations,
+ map_image_path_to_id,
+ out_of_bounds_keypoints,
+ pad_to_length,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+@dataclass(frozen=True)
+class PoseDatasetParameters:
+ """Parameters for a pose dataset
+
+ Attributes:
+ bodyparts: the names of bodyparts in the dataset
+ unique_bpts: the names of unique bodyparts, or an empty list
+ individuals: the names of individuals
+ with_center_keypoints: whether to compute center keypoints for individuals
+ color_mode: {"RGB", "BGR"} the mode to load images in
+ top_down_crop_size: for top-down models, the (width, height) to crop bboxes to
+ top_down_crop_margin: for top-down models, the margin to add around bboxes
+ """
+
+ bodyparts: list[str]
+ unique_bpts: list[str]
+ individuals: list[str]
+ with_center_keypoints: bool = False
+ color_mode: str = "RGB"
+ top_down_crop_size: tuple[int, int] | None = None
+ top_down_crop_margin: int | None = None
+
+ @property
+ def num_joints(self) -> int:
+ return len(self.bodyparts)
+
+ @property
+ def num_unique_bpts(self) -> int:
+ return len(self.unique_bpts)
+
+ @property
+ def max_num_animals(self) -> int:
+ return len(self.individuals)
+
+
+@dataclass
+class PoseDataset(Dataset):
+ """A pose dataset"""
+
+ images: list[dict]
+ annotations: list[dict]
+ parameters: PoseDatasetParameters
+ transform: A.BaseCompose | None = None
+ mode: str = "train"
+ task: Task = Task.BOTTOM_UP
+
+ def __post_init__(self):
+ self.image_path_id_map = map_image_path_to_id(self.images)
+ self.annotation_idx_map = map_id_to_annotations(self.annotations)
+ self.img_id_to_index = {
+ img["id"]: index for index, img in enumerate(self.images)
+ }
+ if self.task == Task.TOP_DOWN and (
+ self.parameters.top_down_crop_size is None
+ or self.parameters.top_down_crop_margin is None
+ ):
+ raise ValueError(
+ "You must specify a ``top_down_crop_size`` and ``top_down_crop_margin``"
+ "in your PoseDatasetParameters when the task is TOP_DOWN."
+ )
+
+ self.td_crop_size = self.parameters.top_down_crop_size
+ self.td_crop_margin = self.parameters.top_down_crop_margin
+
+ def __len__(self):
+ # TODO: TD should only return the number of annotations that aren't unique_bodyparts
+ if self.task in (Task.BOTTOM_UP, Task.DETECT):
+ return len(self.images)
+
+ return len(self.annotations)
+
+ def _get_raw_item(self, index: int) -> tuple[str, list[dict], int]:
+ """
+ Retrieve the image path and annotations for the specified index.
+
+ Args:
+ index (int): The index of the item to retrieve.
+
+ Returns:
+ tuple[str, list]: A tuple containing the image path and annotations.
+
+ Note:
+ This method is used by the __getitem__ method to fetch raw data from the dataset storage.
+ If `self.crop` is True, it returns the image path and a list with a single annotation.
+ Otherwise, it returns the image path and a list of annotations for all instances in the image.
+ """
+ img = self.images[index]
+
+ anns = [self.annotations[idx] for idx in self.annotation_idx_map[img["id"]]]
+
+ return img["file_name"], anns, img["id"]
+
+ def _get_raw_item_crop(self, index: int) -> tuple[str, list[dict], int]:
+ ann = self.annotations[index]
+
+ img = self.images[self.img_id_to_index[ann["image_id"]]]
+ return img["file_name"], [ann], img["id"]
+
+ def __getitem__(self, index: int) -> dict:
+ """
+ Gets the item at the specified index from the dataset.
+
+ Args:
+ index: ordered number of the items in the dataset
+
+ Returns:
+ dict: corresponding to the image annotations, with keys:
+ {
+ "image": image tensor of shape (c, h, w),
+ "image_id": the ID of the image,
+ "path": the filepath to the image,
+ "original_size": the original (h, w) size before transforms
+ "offsets": the (x, y) offsets to apply to the keypoints in TD mode
+ "scales": the (x, y) scales to apply to the keypoints in TD mode
+ "annotations": {
+ "keypoints": array of keypoints, invisible keypoints appear as (-1,-1)
+ "keypoints_unique": the unique keypoints, if there are any
+ "area": array of animals area in this image
+ "boxes": the bounding boxes in this image
+ "is_crowd": is_crowd annotations
+ "labels": category_id annotations for boxes
+ },
+ }
+ """
+ image_path, anns, image_id = self._get_data_based_on_task(index)
+
+ image, original_size = self._load_image(image_path)
+ (
+ keypoints,
+ keypoints_unique,
+ bboxes,
+ annotations_merged,
+ ) = self.extract_keypoints_and_bboxes(anns, image.shape)
+
+ # this is applying data augmentations before the cropping
+ # though normalization should be applied after the cropping
+ transformed = self.apply_transform_all_keypoints(
+ image, keypoints, keypoints_unique, bboxes
+ )
+ image = transformed["image"]
+ keypoints = transformed["keypoints"]
+ keypoints_unique = transformed["keypoints_unique"]
+ bboxes = transformed["bboxes"]
+ offsets = (0, 0)
+ scales = (1.0, 1.0)
+
+ if self.task == Task.TOP_DOWN:
+ if len(bboxes) > 1:
+ raise ValueError(
+ "There can only be one bbox per item in TD datasets, found "
+ f"{bboxes} for {index} (image {image_path})"
+ )
+ bboxes = bboxes.astype(int)
+
+ if bboxes[0, 2] == 0 or bboxes[0, 3] == 0:
+ # bbox was augmented out of the image; blank image, no keypoints
+ keypoints[..., 2] = 0.0
+ image = np.zeros(
+ (self.td_crop_size[1], self.td_crop_size[0], image.shape[-1]),
+ dtype=image.dtype,
+ )
+ else:
+ image, offsets, scales = top_down_crop(
+ image, bboxes[0], self.td_crop_size, margin=self.td_crop_margin,
+ )
+ keypoints[:, :, 0] = (keypoints[:, :, 0] - offsets[0]) / scales[0]
+ keypoints[:, :, 1] = (keypoints[:, :, 1] - offsets[1]) / scales[1]
+ bboxes = bboxes[:1]
+ bboxes[..., 0] = (bboxes[..., 0] - offsets[0]) / scales[0]
+ bboxes[..., 1] = (bboxes[..., 1] - offsets[1]) / scales[1]
+ bboxes[..., 2] = bboxes[..., 2] / scales[0]
+ bboxes[..., 3] = bboxes[..., 3] / scales[1]
+
+ # as a RandomBBoxTransform can be added, keypoints may be outside of the
+ # image after the crop
+ oob_mask = out_of_bounds_keypoints(keypoints, self.td_crop_size)
+ if np.sum(oob_mask) > 0:
+ keypoints[oob_mask, 2] = 0.0
+
+ if self.parameters.with_center_keypoints:
+ keypoints = self.add_center_keypoints(keypoints)
+
+ return self._prepare_final_data_dict(
+ image,
+ keypoints,
+ keypoints_unique,
+ original_size,
+ image_path,
+ bboxes,
+ image_id,
+ annotations_merged,
+ offsets,
+ scales,
+ )
+
+ def _prepare_final_data_dict(
+ self,
+ image: np.ndarray,
+ keypoints: np.ndarray,
+ keypoints_unique: np.ndarray,
+ original_size: tuple[int, int],
+ image_path: str,
+ bboxes: np.array,
+ image_id: int,
+ annotations_merged: dict,
+ offsets: tuple[int, int],
+ scales: tuple[float, float],
+ ) -> dict[str, np.ndarray | dict[str, np.ndarray]]:
+ return {
+ "image": image.transpose((2, 0, 1)),
+ "image_id": image_id,
+ "path": image_path,
+ "original_size": np.array(original_size),
+ "offsets": np.array(offsets, dtype=int),
+ "scales": np.array(scales, dtype=float),
+ "annotations": self._prepare_final_annotation_dict(
+ keypoints, keypoints_unique, bboxes, annotations_merged
+ ),
+ }
+
+ def _prepare_final_annotation_dict(
+ self,
+ keypoints: np.ndarray,
+ keypoints_unique: np.ndarray,
+ bboxes: np.array,
+ anns: dict,
+ ) -> dict[str, np.ndarray]:
+ num_animals = self.parameters.max_num_animals
+ if self.task == Task.TOP_DOWN:
+ num_animals = 1
+
+ bbox_widths = np.maximum(1, bboxes[..., 2])
+ bbox_heights = np.maximum(1, bboxes[..., 3])
+ area = bbox_widths * bbox_heights
+ if "individual_id" not in anns:
+ anns["individual_id"] = -np.ones(len(anns["category_id"]), dtype=int)
+
+ # we use ..., :3 to pass the visibility flag along
+ return {
+ "keypoints": pad_to_length(keypoints[..., :3], num_animals, 0).astype(
+ np.single
+ ),
+ "keypoints_unique": keypoints_unique[..., :3].astype(np.single),
+ "with_center_keypoints": self.parameters.with_center_keypoints,
+ "area": pad_to_length(area, num_animals, 0).astype(np.single),
+ "boxes": pad_to_length(bboxes, num_animals, 0).astype(np.single),
+ "is_crowd": pad_to_length(anns["iscrowd"], num_animals, 0).astype(int),
+ "labels": pad_to_length(anns["category_id"], num_animals, -1).astype(int),
+ "individual_ids": pad_to_length(
+ anns["individual_id"], num_animals, -1
+ ).astype(int),
+ }
+
+ def _load_image(self, image_path):
+ image = cv2.imread(image_path)
+ if self.parameters.color_mode == "RGB":
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ return image, image.shape
+
+ def _get_data_based_on_task(self, index: int) -> tuple[str, list[dict], int]:
+ """
+ Retrieve data based on the specified task.
+
+ For the 'TD' (top-down pose estimation) task:
+ - Provides a cropped image and its annotations.
+ - The shape of annotations['keypoints'] is (1, num_joints, 2).
+
+ For 'BU' and 'DT' tasks:
+ - Provides the full, non-cropped image and its annotations.
+ - The shape of annotations['keypoints'] is (max_num_animals, num_joints, 2).
+
+ Args:
+ index: Index of the item in the dataset.
+
+ Returns:
+ tuple: Tuple containing the image path, annotations, and image ID.
+ """
+ if self.task == Task.TOP_DOWN:
+ return self._get_raw_item_crop(index)
+ elif self.task in [Task.BOTTOM_UP, Task.DETECT]:
+ return self._get_raw_item(index)
+
+ raise ValueError(f"Unknown task: {self.task}")
+
+ def apply_transform_all_keypoints(
+ self,
+ image: np.ndarray,
+ keypoints: np.ndarray,
+ keypoints_unique: np.ndarray,
+ bboxes: np.ndarray,
+ ) -> dict[str, np.ndarray]:
+ """Transforms the image using this class's transform
+
+ Args:
+ image: the image to transform
+ keypoints: an array of shape (num_individuals, num_joints, 3) containing
+ the keypoints in the image
+ keypoints_unique: an array of shape (num_unique_bodyparts, 3) containing
+ the unique keypoints in the image
+ bboxes: the bounding boxes in the image
+
+ Returns:
+ the augmented image, keypoints and bboxes, in format
+ {
+ "image": (h, w, c),
+ "keypoints": (num_individuals, num_joints, 3),
+ "keypoints_unique": (num_unique_bodyparts, 3),
+ "bboxes": (4,),
+ }
+ """
+ class_labels = [
+ f"individual{i}_{bpt}"
+ for i in range(len(keypoints))
+ for bpt in self.parameters.bodyparts
+ ] + [f"unique_{bpt}" for bpt in self.parameters.unique_bpts]
+
+ all_keypoints = keypoints.reshape(-1, 3)
+ if self.parameters.num_unique_bpts > 0:
+ all_keypoints = np.concatenate([all_keypoints, keypoints_unique], axis=0)
+
+ transformed = apply_transform(
+ self.transform, image, all_keypoints, bboxes, class_labels=class_labels
+ )
+ if self.parameters.num_unique_bpts > 0:
+ keypoints = transformed["keypoints"][
+ : -self.parameters.num_unique_bpts
+ ].reshape(*keypoints.shape)
+ keypoints_unique = transformed["keypoints"][
+ -self.parameters.num_unique_bpts :
+ ]
+ keypoints_unique = keypoints_unique.reshape(
+ self.parameters.num_unique_bpts, 3
+ )
+ else:
+ keypoints = transformed["keypoints"].reshape(*keypoints.shape)
+ keypoints_unique = np.zeros((0,))
+
+ transformed["keypoints"] = keypoints
+ transformed["keypoints_unique"] = keypoints_unique
+ transformed["bboxes"] = np.array(transformed["bboxes"])
+ if len(transformed["bboxes"]) == 0:
+ transformed["bboxes"] = np.zeros((0, 4))
+
+ return transformed
+
+ @staticmethod
+ def crop(
+ image: np.ndarray,
+ keypoints,
+ coords: tuple[tuple[int, int], tuple[int, int]],
+ output_size: tuple[int, int],
+ ) -> tuple[np.ndarray, np.ndarray, tuple[int, int], tuple[int, int]]:
+ """
+ Crop the image based on a given bounding box and resize it to the desired output size.
+
+ Args:
+ image: the image to transform
+ keypoints: an array of shape (num_individuals, num_joints, 3) containing
+ the keypoints in the image
+ coords: A bounding box defined as ((x_center, y_center), (width, height)).
+ output_size: Desired size for the output cropped, padded and resized image.
+
+ Returns:
+ Cropped (and possibly padded) and resized image.
+ Offsets used for cropping.
+ Padding sizes.
+ Scale factor used to resize the image.
+ """
+ return _crop_image_keypoints(image, keypoints, coords, output_size)
+
+ def extract_keypoints_and_bboxes(
+ self, anns: list[dict], image_shape: tuple[int, int, int]
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray]]:
+ """
+ Args:
+ anns: COCO-style annotations
+ image_shape: the (h, w, c) shape of the image for which to get annotations
+
+ Returns:
+ keypoints with shape (n_annotation, num_joints, 3)
+ unique_keypoints with shape (num_unique_bpts, 3)
+ bboxes in xywh format with shape (n_annotation, 4)
+ annotations_merged, where each key contains n_annotation values
+ """
+ return _extract_keypoints_and_bboxes(
+ anns,
+ image_shape,
+ self.parameters.num_joints,
+ self.parameters.num_unique_bpts,
+ )
+
+ @staticmethod
+ def add_center_keypoints(keypoints: np.ndarray) -> np.ndarray:
+ """Adds a keypoint in the mean of each individual
+
+ Args:
+ keypoints: shape (num_idv, num_kpts, 3)
+
+ Returns:
+ keypoints with centers, of shape (num_idv, num_kpts + 1, 3)
+ """
+ num_idv = keypoints.shape[0]
+ centers = np.full((num_idv, 1, 3), np.nan)
+
+ keypoints_xy = keypoints.copy()[..., :2]
+ keypoints_xy[keypoints[..., 2] <= 0] = np.nan
+
+ # only set centers for individuals where at least 1 bodypart is visible
+ vis_mask = (~np.isnan(keypoints_xy) > 0).all(axis=2).any(axis=1)
+ if np.any(vis_mask):
+ centers[vis_mask, 0, :2] = np.nanmean(keypoints_xy[vis_mask], axis=1)
+
+ masked_centers = np.any(np.isnan(centers[:, 0, :2]), axis=1)
+ centers[masked_centers, 0, 2] = 0
+ centers[~masked_centers, 0, 2] = 2
+ np.nan_to_num(centers, copy=False, nan=0)
+
+ return np.concatenate((keypoints, centers), axis=1)
diff --git a/deeplabcut/pose_estimation_pytorch/data/dlcloader.py b/deeplabcut/pose_estimation_pytorch/data/dlcloader.py
new file mode 100644
index 0000000000..aa5d18d397
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/dlcloader.py
@@ -0,0 +1,678 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Class implementing the Loader for DeepLabCut projects"""
+from __future__ import annotations
+
+import logging
+import pickle
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import scipy.io as sio
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.core.engine import Engine
+from deeplabcut.pose_estimation_pytorch.data.base import Loader
+from deeplabcut.pose_estimation_pytorch.data.dataset import PoseDatasetParameters
+from deeplabcut.pose_estimation_pytorch.data.utils import read_image_shape_fast
+
+
+class DLCLoader(Loader):
+ """A Loader for DeepLabCut projects"""
+
+ def __init__(
+ self,
+ config: str | Path | dict,
+ trainset_index: int = 0,
+ shuffle: int = 0,
+ modelprefix: str = "",
+ ):
+ """
+ Args:
+ config: Path to the DeepLabCut project config, or the project config itself
+ trainset_index: the index of the TrainingsetFraction for which to load data
+ shuffle: the index of the shuffle for which to load data
+ modelprefix: the modelprefix for the shuffle
+ """
+ if isinstance(config, (str, Path)):
+ self._project_root = Path(config).parent
+ self._project_config = af.read_config(str(config))
+ else:
+ self._project_root = Path(config["project_path"])
+ self._project_config = config
+
+ self._shuffle = shuffle
+ self._trainset_index = trainset_index
+ self._train_frac = self._project_config["TrainingFraction"][trainset_index]
+ self._model_folder = af.get_model_folder(
+ self._train_frac,
+ shuffle,
+ self._project_config,
+ engine=Engine.PYTORCH,
+ modelprefix=modelprefix,
+ )
+ self._evaluation_folder = af.get_evaluation_folder(
+ trainFraction=self._train_frac,
+ shuffle=shuffle,
+ cfg=self._project_config,
+ engine=Engine.PYTORCH,
+ modelprefix=modelprefix,
+ )
+
+ super().__init__(
+ self._project_root
+ / self._model_folder
+ / "train"
+ / Engine.PYTORCH.pose_cfg_name
+ )
+
+ # lazy-load split and DataFrames
+ self._split: dict[str, list[int]] | None = None
+ self._loaded_df: dict[str, pd.DataFrame] | None = None
+ self._resolutions = set()
+
+ @property
+ def project_cfg(self) -> dict:
+ """Returns: the configuration for the DeepLabCut project"""
+ return self._project_config
+
+ @property
+ def df(self) -> pd.DataFrame:
+ """Returns: The ground truth dataframe. Should not be modified."""
+ return self._dfs["full"]
+
+ @property
+ def df_test(self) -> pd.DataFrame:
+ """Returns: A copy of the DataFrame containing the test data."""
+ return self._dfs["test"].copy()
+
+ @property
+ def df_train(self) -> pd.DataFrame:
+ """Returns: A copy of the DataFrame containing the training data."""
+ return self._dfs["train"].copy()
+
+ def image_resolutions(self) -> set[tuple[int, int]]:
+ """Returns: The collection of image resolutions present in the dataset"""
+ return self._resolutions
+
+ @property
+ def evaluation_folder(self) -> Path:
+ """Returns: The path to the evaluation folder"""
+ return self._project_root / self._evaluation_folder
+
+ @property
+ def project_path(self) -> Path:
+ """Returns: The path to the DeepLabCut project"""
+ return self._project_root
+
+ @property
+ def shuffle(self) -> int:
+ """Returns: the shuffle being loaded"""
+ return self._shuffle
+
+ @property
+ def train_fraction(self) -> float:
+ """Returns: the fraction of the dataset used for training"""
+ return self._train_frac
+
+ @property
+ def split(self) -> dict[str, list[int]]:
+ if self._split is None:
+ self._split = self.load_split(
+ self._project_config, self._trainset_index, self.shuffle
+ )
+
+ return self._split
+
+ def get_dataset_parameters(self) -> PoseDatasetParameters:
+ """Retrieves dataset parameters based on the instance's configuration.
+
+ Returns:
+ An instance of the PoseDatasetParameters with the parameters set.
+ """
+ crop_cfg = self.model_cfg["data"]["train"].get("top_down_crop", {})
+ crop_w, crop_h = crop_cfg.get("width", 256), crop_cfg.get("height", 256)
+ crop_margin = crop_cfg.get("margin", 0)
+
+ return PoseDatasetParameters(
+ bodyparts=self.model_cfg["metadata"]["bodyparts"],
+ unique_bpts=self.model_cfg["metadata"]["unique_bodyparts"],
+ individuals=self.model_cfg["metadata"]["individuals"],
+ with_center_keypoints=self.model_cfg.get("with_center_keypoints", False),
+ color_mode=self.model_cfg.get("color_mode", "RGB"),
+ top_down_crop_size=(crop_w, crop_h),
+ top_down_crop_margin=crop_margin,
+ )
+
+ def load_data(self, mode: str = "train") -> dict:
+ """Loads DeepLabCut data into COCO-style annotations
+
+ This function reads data from h5 file, split the data and returns it in
+ COCO-like format
+
+ Args:
+ mode: mode indicating whether to use 'train' or 'test' data.
+
+ Raises:
+ AttributeError: if the specified mode (train or test) does not exist.
+
+ Returns:
+ the coco-style annotations
+ """
+ if mode not in ["train", "test"]:
+ raise AttributeError(f"Unknown mode: {mode}")
+ if mode not in self._dfs:
+ raise ValueError(f"No split for: {mode} (found {self._dfs.keys()})")
+ if self._dfs[mode] is None:
+ raise ValueError(f"No data in {mode} split for this shuffle!")
+
+ params = self.get_dataset_parameters()
+ data = self.to_coco(str(self._project_root), self._dfs[mode], params)
+ with_bbox = self._compute_bboxes(
+ data["images"], data["annotations"], method="keypoints"
+ )
+ data["annotations"] = with_bbox
+ return data
+
+ def load_ground_truth(
+ self,
+ config: dict,
+ trainset_index: int,
+ shuffle: int,
+ ) -> tuple[dict[str, pd.DataFrame], set[tuple[int, int]]]:
+ """Loads the ground truth dataset for a DeepLabCut project.
+
+ Args:
+ config: the DeepLabCut project configuration file
+ trainset_index: the TrainingsetFraction for which to load data
+ shuffle: the index of the shuffle for which to load data
+
+ Returns: ground_truth_dataframes, image_resolutions
+ ground_truth_dataframes: a dictionary containing the different DataFrames
+ for the annotated DeepLabCut data for the current iteration
+ image_resolutions: all possible image resolutions in the dataset
+
+ Raises:
+ ValueError: if the data contained in the ground truth HDF does not contain
+ a dataframe.
+ """
+ trainset_dir = Path(config["project_path"]) / af.get_training_set_folder(config)
+ dataset_path = f"CollectedData_{config['scorer']}.h5"
+ train_frac = int(100 * config["TrainingFraction"][trainset_index])
+ project_id = f"{config['Task']}_{config['scorer']}"
+ dataset_file = trainset_dir / f"{project_id}{train_frac}shuffle{shuffle}"
+ params = self.get_dataset_parameters()
+
+ # as in TF DeepLabCut, load the training data from the .mat/.pickle file
+ if config.get("multianimalproject", False):
+ image_sizes, df_train = _load_pickle_dataset(
+ dataset_file.with_suffix(".pickle"),
+ config["scorer"],
+ params=params,
+ )
+ else:
+ image_sizes, df_train = _load_mat_dataset(
+ dataset_file.with_suffix(".mat"),
+ config["scorer"],
+ params=params,
+ )
+
+ # load the full dataset file
+ df = pd.read_hdf(trainset_dir / dataset_path)
+ if not isinstance(df, pd.DataFrame):
+ raise ValueError(
+ f"The ground truth data in {trainset_dir} must contain a DataFrame! "
+ f"Found {df}"
+ )
+
+ # load the data splits, check that there's nothing suspect
+ dfs = self.split_data(df, self.split)
+ dfs["full"] = df
+ # let's not validate for now
+ # dfs = _validate_dataframes(dfs, df_train)
+ return dfs, image_sizes
+
+ @staticmethod
+ def load_split(
+ config: dict,
+ trainset_index: int = 0,
+ shuffle: int = 0,
+ ) -> dict[str, list[int]]:
+ """Loads the train/test split for a DeepLabCut shuffle
+
+ Args:
+ config: the DeepLabCut project config
+ trainset_index: the TrainingsetFraction for which to load data
+ shuffle: the index of the shuffle for which to load data
+
+ Return:
+ the {"train": [train_ids], "test": [test_ids]} data split
+ """
+ trainset_dir = Path(config["project_path"]) / af.get_training_set_folder(config)
+ train_frac = int(100 * config["TrainingFraction"][trainset_index])
+ shuffle_id = f"{config['Task']}_{train_frac}shuffle{shuffle}.pickle"
+ doc_path = trainset_dir / f"Documentation_data-{shuffle_id}"
+
+ with open(doc_path, "rb") as f:
+ meta = pickle.load(f)
+
+ train_ids = [int(i) for i in meta[1]]
+ test_ids = [int(i) for i in meta[2]]
+ return {"train": train_ids, "test": test_ids}
+
+ @staticmethod
+ def split_data(
+ dlc_df: pd.DataFrame,
+ split: dict[str, list[int]],
+ ) -> dict[str, pd.DataFrame | None]:
+ """
+ Splits a DeepLabCut DataFrame into train/test dataframes
+
+ Args:
+ dlc_df: the dataframe containing the labeled data
+ split: the train/test indices
+
+ Returns:
+ a dictionary containing the same keys as the split dictionary, where the
+ values are the rows of dlc_df with index in the split, or None if there are
+ no indices in that split
+ """
+ split_dfs = {}
+ for k, indices in split.items():
+ if len(indices) == 0:
+ split_dfs[k] = None
+ else:
+ split_dfs[k] = dlc_df.iloc[indices]
+ return split_dfs
+
+ @staticmethod
+ def to_coco(
+ project_root: str | Path,
+ df: pd.DataFrame,
+ parameters: PoseDatasetParameters,
+ ) -> dict:
+ """Formerly Shaokai's function
+
+ Args:
+ project_root: the path to the project root
+ df: the DLC-format annotation dataframe to convert to a COCO-format dict
+ parameters: the parameters for pose estimation
+
+ Returns:
+ the coco format data
+ """
+ with_individuals = "individuals" in df.columns.names
+ if not with_individuals and (
+ len(parameters.individuals) > 1 or len(parameters.unique_bpts) > 0
+ ):
+ raise ValueError(
+ "The DataFrame contains single-animal annotations (for a single, "
+ "individual), but the parameters suggest this is a multi-animal project"
+ f": {parameters} (with multiple individuals or unique bodyparts)"
+ )
+
+ categories = [
+ {
+ "id": 1,
+ "name": "animals",
+ "supercategory": "animal",
+ "keypoints": parameters.bodyparts,
+ },
+ ]
+ individuals = [idv for idv in parameters.individuals]
+ if len(parameters.unique_bpts) > 0:
+ individuals += ["single"]
+ categories.append(
+ {
+ "id": 2,
+ "name": "unique_bodypart",
+ "supercategory": "animal",
+ "keypoints": parameters.unique_bpts,
+ }
+ )
+
+ anns, images = [], []
+ base_path = Path(project_root)
+ for idx, row in df.iterrows():
+ image_id = len(images) + 1
+ rel_path = Path(*idx) if isinstance(idx, tuple) else Path(str(idx))
+ path = str(base_path / rel_path)
+ _, height, width = read_image_shape_fast(path)
+ images.append(
+ {
+ "id": image_id,
+ "file_name": path,
+ "width": width,
+ "height": height,
+ }
+ )
+
+ for idv_idx, idv in enumerate(individuals):
+ category_id = 1
+ individual_id = idv_idx
+ if with_individuals:
+ if idv == "single":
+ category_id = 2
+ individual_id = -1
+ data = row.xs(idv, level="individuals")
+ else:
+ data = row
+
+ raw_keypoints = data.to_numpy().reshape((-1, 2))
+ keypoints = np.zeros((len(raw_keypoints), 3))
+ keypoints[:, :2] = raw_keypoints
+ is_visible = np.logical_and(
+ ~pd.isnull(raw_keypoints).all(axis=1),
+ np.logical_and(
+ np.logical_and(
+ 0 < keypoints[..., 0],
+ keypoints[..., 0] < width,
+ ),
+ np.logical_and(
+ 0 < keypoints[..., 1],
+ keypoints[..., 1] < height,
+ ),
+ ),
+ )
+ keypoints[:, 2] = np.where(is_visible, 2, 0)
+ num_keypoints = is_visible.sum()
+ if num_keypoints > 0:
+ anns.append(
+ {
+ "id": len(anns) + 1,
+ "image_id": image_id,
+ "category_id": category_id,
+ "individual": idv,
+ "individual_id": individual_id,
+ "num_keypoints": num_keypoints,
+ "keypoints": keypoints,
+ "iscrowd": 0,
+ }
+ )
+
+ return {"annotations": anns, "categories": categories, "images": images}
+
+ @property
+ def _dfs(self) -> dict[str, pd.DataFrame]:
+ """Lazy-loading of the training dataset dataframes"""
+ if self._loaded_df is None:
+ self._loaded_df, image_sizes = self.load_ground_truth(
+ self._project_config,
+ trainset_index=self._trainset_index,
+ shuffle=self.shuffle,
+ )
+ self._resolutions = self._resolutions.union(image_sizes)
+
+ return self._loaded_df
+
+
+def _load_mat_dataset(
+ file: Path,
+ scorer: str,
+ params: PoseDatasetParameters,
+) -> tuple[set[tuple[int, int]], pd.DataFrame]:
+ """Loads the training dataset stored as a .mat file
+
+ Returns: images_sizes, dlc_dataset
+ images_sizes: all possible images sizes in the dataset
+ dlc_dataset: the dataset in a DLC-format DataFrame
+ """
+ if not params.max_num_animals == 1:
+ raise RuntimeError(
+ f"Cannot load a multi-animal pose dataset from a `.mat` file ({file})"
+ )
+
+ raw_data = sio.loadmat(str(file))
+ dataset = raw_data["dataset"]
+ num_images = dataset.shape[1]
+
+ image_sizes = set()
+ index, data = [], []
+ for i in range(num_images):
+ item = dataset[0, i]
+
+ # add the image size
+ c, h, w = item[1][0]
+ image_sizes.add((h, w))
+
+ # parse image path
+ raw_path = item[0][0]
+ if isinstance(raw_path, str):
+ image_path = Path(raw_path).parts[-3:]
+ else:
+ image_path = tuple([p.strip() for p in raw_path])
+ index.append(image_path)
+
+ # parse data
+ keypoints = np.zeros((1, params.num_joints, 2))
+ keypoints.fill(np.nan)
+ if len(item) >= 3:
+ joints = item[2][0][0]
+ for joint_id, x, y in joints:
+ keypoints[0, joint_id, 0] = x
+ keypoints[0, joint_id, 1] = y
+
+ joint_id = joints[:, 0]
+ if joint_id.size != 0: # make sure joint ids are 0-indexed
+ assert (joint_id < params.num_joints).any()
+ joints[:, 0] = joint_id
+
+ data.append(keypoints)
+
+ dataframe = pd.DataFrame(
+ data=np.stack(data, axis=0).reshape((num_images, -1)),
+ index=pd.MultiIndex.from_tuples(index),
+ columns=build_dlc_dataframe_columns(scorer, params, False),
+ )
+ dataframe = dataframe.sort_index(axis=0)
+ return image_sizes, dataframe
+
+
+def _load_pickle_dataset(
+ file: Path,
+ scorer: str,
+ params: PoseDatasetParameters,
+) -> tuple[set[tuple[int, int]], pd.DataFrame]:
+ """Loads the training dataset stored as a .mat file
+
+ Returns: images_sizes, dlc_dataset
+ images_sizes: all possible images sizes in the dataset
+ dlc_dataset: the dataset in a DLC-format DataFrame
+ """
+ with open(file, "rb") as f:
+ raw_data = pickle.load(f)
+
+ num_images = len(raw_data)
+ image_sizes = set()
+ index, data = [], []
+ data_unique = None
+ if params.num_unique_bpts > 0:
+ data_unique = []
+
+ for image_data in raw_data:
+ # add image path
+ index.append(image_data["image"])
+
+ # add image size
+ c, h, w = image_data["size"]
+ image_sizes.add((h, w))
+
+ # add keypoints
+ keypoints = np.zeros((params.max_num_animals, params.num_joints, 2))
+ keypoints.fill(np.nan)
+ keypoints_unique = None
+ for idv_idx, idv_bodyparts in image_data.get("joints", {}).items():
+ if idv_idx < params.max_num_animals:
+ for joint_id, x, y in idv_bodyparts:
+ bodypart = int(joint_id)
+ keypoints[idv_idx, bodypart, 0] = x
+ keypoints[idv_idx, bodypart, 1] = y
+
+ elif (
+ idv_idx == params.max_num_animals
+ and data_unique is not None
+ and keypoints_unique is None
+ ):
+ keypoints_unique = np.zeros((params.num_unique_bpts, 2))
+ keypoints_unique.fill(np.nan)
+ for joint_id, x, y in idv_bodyparts:
+ unique_bpt_id = int(joint_id) - params.num_joints
+ keypoints_unique[unique_bpt_id, 0] = x
+ keypoints_unique[unique_bpt_id, 1] = y
+
+ else:
+ raise ValueError(f"Malformed dataset: {params}, {image_data}")
+
+ data.append(keypoints)
+ if data_unique is not None:
+ if keypoints_unique is None:
+ keypoints_unique = np.zeros((params.num_unique_bpts, 2))
+ keypoints_unique.fill(np.nan)
+ data_unique.append(keypoints_unique)
+
+ data = np.stack(data, axis=0).reshape((num_images, -1))
+ if data_unique is not None:
+ data_unique = np.stack(data_unique, axis=0).reshape((num_images, -1))
+ data = np.concatenate([data, data_unique], axis=1)
+
+ dataframe = pd.DataFrame(
+ data=data,
+ index=pd.MultiIndex.from_tuples(index),
+ columns=build_dlc_dataframe_columns(scorer, params, False),
+ )
+ dataframe = dataframe.sort_index(axis=0)
+ return image_sizes, dataframe
+
+
+def _validate_dataframes(
+ dfs: dict[str, pd.DataFrame],
+ df_train: pd.DataFrame,
+ strict: bool = False,
+) -> dict[str, pd.DataFrame]:
+ """Validates the training/test DataFrames
+
+ Performs the following validation steps:
+ 1. Checks that the training data loaded from CollectedData.h5 matches the
+ training data stored in the ".mat" or ".pickle" file.
+ 2. Checks that there are no duplicate entries in the DataFrames (if there are
+ any, removes them)
+ 3. Checks that there is no data leak between the training and test set (if there
+ is, prints a warning)
+
+ Args:
+ dfs: the "full" and split DataFrames loaded from the H5 file
+ df_train: the training data loaded from the ".mat" or ".pickle" file
+ strict: Whether to fail if the data does not pass validation (instead of
+ attempting a fix).
+
+ Returns:
+ The validated and sanitized DataFrames
+
+ Raises:
+ ValueError: if strict and there is a small fixable error, or if there are images
+ that are present in both the training and test set.
+ """
+ error = False
+
+ # checks that all images in the .pickle/.mat file are in the HDF
+ pickle_train_images = set(df_train.index)
+ hdf_train_images = set(dfs["train"].index)
+ missing_images = pickle_train_images - hdf_train_images
+ extra_images = hdf_train_images - pickle_train_images
+ if len(missing_images) > 0:
+ error = True
+ logging.debug(
+ f"Found images in the dataset file which were not in H5: {missing_images}"
+ )
+ if len(extra_images) > 0:
+ error = True
+ logging.debug(
+ f"Found images in the H5 file which were not in the dataset: {extra_images}"
+ )
+
+ # checks that the data is close for the similar images
+ train_index = list(hdf_train_images.intersection(pickle_train_images))
+ data_h5 = np.nan_to_num(dfs["full"].loc[train_index], nan=-1)
+ data_pickle_mat = np.nan_to_num(df_train, nan=-1)
+ if not np.isclose(data_h5, data_pickle_mat, atol=0.1).all():
+ error = True
+ logging.debug(
+ "Found differences between the training-dataset HDF (.h5) data and the "
+ "training data found. This might be the case if you refined your data "
+ "after creating the dataset, and then created a new shuffle."
+ )
+
+ # checks that there are no duplicate entries
+ dfs_clean = {}
+ for split, df in dfs.items():
+ dup = df.index.duplicated(keep="first")
+ num_dup = dup.sum()
+ if dup.sum() > 0:
+ error = True
+ logging.debug(f"Found {num_dup} duplicates in {split}: {df[dup].index}")
+ dfs_clean[split] = df[~dup]
+ else:
+ dfs_clean[split] = df[~dup]
+
+ # check for leaks
+ if dfs["test"] is not None:
+ train_images = set(dfs["train"].index)
+ test_images = set(dfs["test"].index)
+ leak = train_images.intersection(test_images)
+ if len(leak) > 0:
+ logging.warning(
+ f"Found images both in the training and test set: {leak}! To resolve "
+ "this issue please try the following:\n"
+ f" 1. Check that each video is listed exactly once in your project's"
+ f"`config.yaml`\n"
+ f" 2. Make sure all of your videos have different names."
+ f" 3. You can use `dropduplicatesinannotatinfiles` and "
+ f"`comparevideolistsanddatafolders` to ensure that there are no more "
+ f"duplicates"
+ f" 3. Switch to a new iteration and create a fresh training dataset"
+ )
+
+ if error and strict:
+ raise ValueError(f"Found errors when validating the dataset")
+
+ return dfs
+
+
+def build_dlc_dataframe_columns(
+ scorer: str,
+ parameters: PoseDatasetParameters,
+ with_likelihood: bool,
+) -> pd.MultiIndex:
+ """Builds the columns for a DeepLabCut DataFrame
+
+ Args:
+ scorer: the scorer name
+ parameters: the parameters for the project
+ with_likelihood: whether the DataFrame contains pose likelihood
+
+ Returns:
+ the multi-index columns for the DataFrame
+ """
+ levels = ["scorer", "individuals", "bodyparts", "coords"]
+ kpt_entries = ["x", "y"]
+ if with_likelihood:
+ kpt_entries.append("likelihood")
+
+ columns = []
+ for i in parameters.individuals:
+ for b in parameters.bodyparts:
+ columns += [(scorer, i, b, entry) for entry in kpt_entries]
+
+ for unique_bpt in parameters.unique_bpts:
+ columns += [(scorer, "single", unique_bpt, entry) for entry in kpt_entries]
+
+ return pd.MultiIndex.from_tuples(columns, names=levels)
diff --git a/deeplabcut/pose_estimation_pytorch/data/helper.py b/deeplabcut/pose_estimation_pytorch/data/helper.py
new file mode 100644
index 0000000000..de1d632b5e
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/helper.py
@@ -0,0 +1,96 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABCMeta
+
+
+def cfg_getter(key, default=None):
+ def _getter(cfg):
+ return cfg.get(key, default)
+
+ return _getter
+
+
+def class_property(func, arg_func):
+ """
+ Decorator to create a class property.
+
+ Parameters:
+ - func: Callable that represents the logic of the property.
+ - arg_func: Callable that provides the arguments for `func`.
+
+ Returns:
+ - A property with the logic encapsulated in `func` and arguments derived from `arg_func`.
+ """
+
+ def decorator_wrapper(method):
+ def wrapper(self):
+ return func(arg_func(self))
+
+ return property(wrapper)
+
+ return decorator_wrapper
+
+
+class PropertyMeta(type):
+ """
+ Metaclass for creating class properties in a more organized and systematic manner.
+
+ This metaclass allows a class to define its properties using a simple dictionary
+ structure (`properties`). The dictionary keys represent the property names,
+ while the values are tuples containing two callables:
+ 1. The function that represents the logic of the property.
+ 2. The function that provides the arguments for the logic function.
+
+ Usage:
+ class MyClass(metaclass=PropertyMeta):
+ properties = {
+ 'property_name': (logic_function, arguments_function),
+ # ... more properties ...
+ }
+
+ For each property specified in the `properties` dictionary, the metaclass will
+ generate a real property that uses the logic from `logic_function` and
+ arguments from `arguments_function`.
+
+ Attributes:
+ - properties (dict): Dictionary containing property names as keys and tuples
+ of (logic_function, arguments_function) as values.
+ """
+
+ def __new__(cls, name, bases, attrs):
+ if "properties" not in attrs:
+ raise AttributeError(f"{name} must define a 'properties' dictionary.")
+ properties = attrs.get("properties", {})
+ for prop_name, (func, arg_func) in properties.items():
+ attrs[prop_name] = class_property(func, arg_func)(lambda self: None)
+ return super().__new__(cls, name, bases, attrs)
+
+
+class CombinedPropertyMeta(ABCMeta, PropertyMeta):
+ """
+ Combined metaclass that integrates the functionalities of both `ABCMeta` and `BasePropertyMeta`.
+
+ This metaclass is useful in scenarios where a class needs to use both abstract methods (from `ABCMeta`)
+ and the property definition utilities provided by `BasePropertyMeta`.
+
+ By using this metaclass, a class can be both an abstract class (with abstract methods and/or properties)
+ and can also define properties in the structured manner facilitated by `PropertyMeta`.
+
+ Inherits:
+ - ABCMeta: Metaclass for base classes that include abstract methods.
+ - PropertyMeta: Metaclass that facilitates structured property definitions.
+
+ Note:
+ When defining a class using `CombinedPropertyMeta`, ensure that the class also inherits
+ from `ABC` to make it compatible with the `ABCMeta` behavior.
+ """
diff --git a/deeplabcut/pose_estimation_pytorch/data/image.py b/deeplabcut/pose_estimation_pytorch/data/image.py
new file mode 100644
index 0000000000..115b76beb2
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/image.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import copy
+
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms.functional as F
+
+from deeplabcut.pose_estimation_pytorch.data.utils import _compute_crop_bounds
+
+
+def resize_and_random_crop(
+ image: np.ndarray,
+ targets: dict,
+ size: int | tuple[int, int],
+ max_size: int | None = None,
+ max_shift: int | None = None,
+) -> tuple[torch.tensor, dict]:
+ """Resizes images while preserving their aspect ratio
+
+ If size is an integer: resizes to square images.
+ First, resizes the image so that it's short side is equal to `size`. If this
+ makes its long side greater than `max_size`, resizes the long side to `max_size`
+ and the short side to the corresponding value to preserve the aspect ratio.
+
+ Then, the image is cropped to a size-by-size square with a random crop.
+
+ If size is a tuple, resize images to (w=size[1], h=size[0])
+ First, rescales the image while preserving the aspect ratio such that both its
+ width and height are greater or equal to the target width/height for the image
+ (where either the width/height is the target width/height). If this makes its
+ long side greater than `max_size`, resizes the long side to `max_size`.
+
+ Then, the image is cropped to (w=size[1], h=size[0]) with a random crop.
+
+ Args:
+ image: an image of shape (C, H, W)
+ targets: the dictionary containing targets
+ size: the size of the output image (it will be square)
+ max_size: if defined, the maximum size of any side of the output image
+ max_shift: the maximum shift for the crop after resizing
+
+ Returns: image, targets
+ the resized image as a PyTorch tensor
+ the updated targets in the resized image
+ """
+
+ def get_resize_hw(
+ original_size: tuple[int, int], tgt_short_side: int, max_long_side: int | None
+ ) -> tuple[int, int]:
+ short_side, long_side = min(*original_size), max(*original_size)
+ tgt_long_side = int((tgt_short_side / short_side) * long_side)
+
+ # if the image's long side will be too big, make the image smaller
+ if max_long_side is not None and tgt_long_side > max_long_side:
+ tgt_long_side = max_long_side
+ tgt_short_side = int((tgt_long_side / long_side) * short_side)
+
+ # height is the short side
+ if original_size[0] < original_size[1]:
+ return tgt_short_side, tgt_long_side
+
+ # width is the short side
+ return tgt_long_side, tgt_short_side
+
+ def get_resize_preserve_ratio(
+ oh: int, ow: int, tgt_h: int, tgt_w: int, max_long_side: int | None
+ ) -> tuple[int, int]:
+ w_scale = ow / tgt_w
+ h_scale = oh / tgt_h
+ if h_scale <= w_scale:
+ h = tgt_h
+ w = int(ow * (tgt_h / oh))
+ else:
+ h = int(oh * (tgt_w / ow))
+ w = tgt_w
+
+ # if the image's long side will be too big, make the image smaller
+ long_side = max(h, w)
+ if max_long_side is not None and long_side > max_long_side:
+ if h <= w:
+ w = max_long_side
+ h = int(oh * (max_long_side / ow))
+ else:
+ w = int(ow * (max_long_side / oh))
+ h = max_long_side
+
+ return h, w
+
+ oh, ow = image.shape[1:]
+ if isinstance(size, int):
+ h, w = get_resize_hw((oh, ow), tgt_short_side=size, max_long_side=max_size)
+ tgt_h, tgt_w = size, size
+ else:
+ h, w = get_resize_preserve_ratio(
+ oh, ow, size[0], size[1], max_long_side=max_size
+ )
+ tgt_h, tgt_w = size
+
+ scale_x, scale_y = ow / w, oh / h
+ scaled_image = F.resize(torch.tensor(image), [h, w])
+
+ # shift the image
+ if max_shift is None:
+ max_shift = 0
+ extra_x, extra_y = max(0, w - tgt_w), max(0, h - tgt_h)
+ offset_x = np.random.randint(
+ max(-tgt_w // 2, -max(0, tgt_w - w) - max_shift),
+ min(max_shift + extra_x, extra_x + (min(w, tgt_w) // 2)),
+ )
+ offset_y = np.random.randint(
+ max(-tgt_h // 2, -max(0, tgt_h - h) - max_shift),
+ min(max_shift + extra_y, extra_y + (min(h, tgt_h) // 2)),
+ )
+
+ # 0-pads, then crops if image size is smaller than output size along any edge
+ scaled_cropped_image = F.crop(scaled_image, offset_y, offset_x, tgt_h, tgt_w)
+
+ # update targets
+ targets = copy.deepcopy(targets)
+
+ # update scales and offsets
+ sx, sy = targets["scales"]
+ ox, oy = targets["offsets"]
+ targets["offsets"] = ox + (offset_x * sx), oy + (offset_y * sy)
+ targets["scales"] = sx * scale_x, sy * scale_y
+
+ # update annotations
+ anns = targets.get("annotations", {})
+
+ kpt_scale = np.array([scale_x, scale_y])
+ kpt_offset = np.array([offset_x, offset_y])
+ for kpt_key in ["keypoints", "keypoints_unique"]:
+ keypoints = anns.get(kpt_key)
+ if keypoints is not None and len(keypoints) > 0:
+ scaled_kpts = keypoints.copy()
+ scaled_kpts[..., :2] = (scaled_kpts[..., :2] / kpt_scale) - kpt_offset
+ scaled_kpts[(scaled_kpts[..., 0] >= tgt_w)] = -1
+ scaled_kpts[(scaled_kpts[..., 1] >= tgt_h)] = -1
+ scaled_kpts[(scaled_kpts[..., :2] < 0).any(axis=-1)] = -1
+ anns[kpt_key] = scaled_kpts
+
+ bbox_scale = np.array([scale_x, scale_y, scale_x, scale_y])
+ bbox_offset = np.array([offset_x, offset_y, 0, 0])
+ for bbox_key in ["boxes"]:
+ boxes = anns.get(bbox_key)
+ if boxes is not None and len(boxes) > 0:
+ scaled_boxes = (boxes / bbox_scale) - bbox_offset
+ scaled_boxes = _compute_crop_bounds(
+ scaled_boxes, (tgt_h, tgt_w, 3), remove_empty=False,
+ )
+ anns[bbox_key] = scaled_boxes
+
+ area = anns.get("area")
+ if area is not None:
+ if "boxes" in anns: # recompute areas from the new bounding boxes
+ widths = np.maximum(anns["boxes"][..., 2], 1)
+ heights = np.maximum(anns["boxes"][..., 3], 1)
+ anns["area"] = widths * heights
+ else: # just rescale
+ scaled_area = area * (scale_x * scale_y)
+ anns["area"] = scaled_area
+
+ return scaled_cropped_image, targets
+
+
+def top_down_crop(
+ image: np.ndarray,
+ bbox: np.ndarray,
+ output_size: tuple[int, int],
+ margin: int = 0,
+ center_padding: bool = False,
+) -> tuple[np.array, tuple[int, int], tuple[float, float]]:
+ """
+ Crops images around bounding boxes for top-down pose estimation. Computes offsets so
+ that coordinates in the original image can be mapped to the cropped one;
+
+ x_cropped = (x - offset_x) / scale_x
+ x_cropped = (y - offset_y) / scale_y
+
+ Bounding boxes are expected to be in COCO-format (xywh).
+
+ Args:
+ image: (h, w, c) the image to crop
+ bbox: (4,) the bounding box to crop around
+ output_size: the (width, height) of the output cropped image
+ margin: a margin to add around the bounding box before cropping
+ center_padding: whether to center the image in the padding if any is needed
+
+ Returns:
+ cropped_image, (offset_x, offset_y), (scale_x, scale_y)
+ """
+ image_h, image_w, c = image.shape
+ out_w, out_h = output_size
+ x, y, w, h = bbox
+
+ cx = x + w / 2
+ cy = y + h / 2
+ w += 2 * margin
+ h += 2 * margin
+
+ input_ratio = w / h
+ output_ratio = out_w / out_h
+ if input_ratio > output_ratio: # h/w < h0/w0 => h' = w * h0/w0
+ h = w / output_ratio
+ elif input_ratio < output_ratio: # w/h < w0/h0 => w' = h * w0/h0
+ w = h * output_ratio
+
+ # cx,cy,w,h will now give the right ratio -> check if padding is needed
+ x1, y1 = int(round(cx - (w / 2))), int(round(cy - (h / 2)))
+ x2, y2 = int(round(cx + (w / 2))), int(round(cy + (h / 2)))
+
+ # pad symmetrically - compute total padding across axis
+ pad_left, pad_right, pad_top, pad_bottom = 0, 0, 0, 0
+ if x1 < 0:
+ pad_left = -x1
+ x1 = 0
+ if x2 > image_w:
+ pad_right = x2 - image_w
+ x2 = image_w
+ if y1 < 0:
+ pad_top = -y1
+ y1 = 0
+ if y2 > image_h:
+ pad_bottom = y2 - image_h
+ y2 = image_h
+
+ w, h = x2 - x1, y2 - y1
+ pad_x = pad_left + pad_right
+ pad_y = pad_top + pad_bottom
+ if center_padding:
+ pad_left = pad_x // 2
+ pad_top = pad_y // 2
+
+ # crop the pixels we care about
+ image_crop = np.zeros((h + pad_y, w + pad_x, c), dtype=image.dtype)
+ image_crop[pad_top:pad_top + h, pad_left:pad_left + w] = image[y1:y2, x1:x2]
+
+ # resize the cropped image
+ image = cv2.resize(image_crop, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
+
+ # compute scale and offset
+ offset = x1 - pad_left, y1 - pad_top
+ scale = (w + pad_x) / out_w, (h + pad_y) / out_h
+ return image, offset, scale
diff --git a/deeplabcut/pose_estimation_pytorch/data/postprocessor.py b/deeplabcut/pose_estimation_pytorch/data/postprocessor.py
new file mode 100644
index 0000000000..7775db41b7
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/postprocessor.py
@@ -0,0 +1,549 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Post-process predictions made by models"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any
+
+import numpy as np
+
+from deeplabcut.pose_estimation_pytorch.data.preprocessor import Context
+from deeplabcut.pose_estimation_pytorch.post_processing.identity import assign_identity
+
+
+class Postprocessor(ABC):
+ """A post-processor can be called on the output of a model
+ TODO: Documentation
+ """
+
+ @abstractmethod
+ def __call__(self, predictions: Any, context: Context) -> Any:
+ """
+ Post-processes the outputs of a model into a single prediction.
+
+ Args:
+ predictions: the predictions made by the model on a single image
+ context: the context returned by the pre-processor with the image
+
+ Returns:
+ a single post-processed prediction
+ """
+ pass
+
+
+def build_bottom_up_postprocessor(
+ max_individuals: int,
+ num_bodyparts: int,
+ num_unique_bodyparts: int,
+ with_identity: bool = False,
+ with_backbone_features: bool = False,
+) -> ComposePostprocessor:
+ """Creates a postprocessor for bottom-up pose estimation (or object detection)
+
+ Args:
+ max_individuals: the maximum number of individuals in a single image
+ num_bodyparts: the number of bodyparts output by the model
+ num_unique_bodyparts: the number of unique_bodyparts output by the model
+ with_identity: whether the model has an identity head
+ with_backbone_features: When True, the backbone features are extracted from
+ the output and saved in a `features` key. The `PoseModel` must have its
+ `output_features` attribute set to True, or this will raise an Exception.
+
+ Returns:
+ A default bottom-up Postprocessor
+ """
+ keys_to_concatenate = {"bodyparts": ("bodypart", "poses")}
+ empty_shapes = {"bodyparts": (num_bodyparts, 3)}
+ keys_to_rescale = ["bodyparts"]
+
+ if num_unique_bodyparts > 0:
+ keys_to_concatenate["unique_bodyparts"] = ("unique_bodypart", "poses")
+ empty_shapes["unique_bodyparts"] = (num_bodyparts, 3)
+ keys_to_rescale.append("unique_bodyparts")
+
+ if with_identity:
+ keys_to_concatenate["identity_heatmap"] = ("identity", "heatmap")
+ empty_shapes["identity_heatmap"] = (1, 1, max_individuals)
+
+ if with_backbone_features:
+ keys_to_concatenate["features"] = ("backbone", "features")
+ empty_shapes["features"] = (num_bodyparts, 0, 1)
+
+ components = [
+ ConcatenateOutputs(
+ keys_to_concatenate=keys_to_concatenate,
+ empty_shapes=empty_shapes,
+ create_empty_outputs=True,
+ ),
+ ]
+
+ if with_identity:
+ components.append(
+ PredictKeypointIdentities(
+ identity_key="identity_scores",
+ identity_map_key="identity_heatmap",
+ pose_key="bodyparts",
+ keep_id_maps=False,
+ )
+ )
+
+ components += [
+ RescaleAndOffset(
+ keys_to_rescale=keys_to_rescale,
+ mode=RescaleAndOffset.Mode.KEYPOINT,
+ ),
+ PadOutputs(
+ max_individuals={
+ "bodyparts": max_individuals,
+ "identity_scores": max_individuals,
+ },
+ pad_value=-1,
+ ),
+ ]
+
+ if with_identity:
+ components.append(
+ AssignIndividualIdentities(
+ identity_key="identity_scores", pose_key="bodyparts",
+ )
+ )
+
+ return ComposePostprocessor(components=components)
+
+
+def build_top_down_postprocessor(
+ max_individuals: int,
+ num_bodyparts: int,
+ num_unique_bodyparts: int,
+ with_backbone_features: bool = False,
+) -> Postprocessor:
+ """Creates a postprocessor for top-down pose estimation
+
+ Args:
+ max_individuals: the maximum number of individuals in a single image
+ num_bodyparts: the number of bodyparts output by the model
+ num_unique_bodyparts: the number of unique_bodyparts output by the model
+ with_backbone_features: When True, the backbone features are extracted from
+ the output and saved in a `features` key. The `PoseModel` must have its
+ `output_features` attribute set to True, or this will raise an Exception.
+
+ Returns:
+ A default top-down Postprocessor
+ """
+ keys_to_concatenate = {"bodyparts": ("bodypart", "poses")}
+ empty_shapes = {"bodyparts": (num_bodyparts, 3)}
+ keys_to_rescale = ["bodyparts"]
+ if num_unique_bodyparts > 0:
+ keys_to_concatenate["unique_bodyparts"] = ("unique_bodypart", "poses")
+ empty_shapes["unique_bodyparts"] = (num_unique_bodyparts, 3)
+ keys_to_rescale.append("unique_bodyparts")
+
+ if with_backbone_features:
+ keys_to_concatenate["features"] = ("backbone", "features")
+ empty_shapes["features"] = (num_bodyparts, 0, 1)
+
+ return ComposePostprocessor(
+ components=[
+ ConcatenateOutputs(
+ keys_to_concatenate=keys_to_concatenate,
+ empty_shapes=empty_shapes,
+ create_empty_outputs=True,
+ ),
+ RescaleAndOffset(
+ keys_to_rescale=keys_to_rescale,
+ mode=RescaleAndOffset.Mode.KEYPOINT_TD,
+ ),
+ AddContextToOutput(keys=["bboxes", "bbox_scores"]),
+ PadOutputs(
+ max_individuals={
+ "bodyparts": max_individuals,
+ "bboxes": max_individuals,
+ "bbox_scores": max_individuals,
+ },
+ pad_value=-1,
+ ),
+ ]
+ )
+
+
+def build_detector_postprocessor(max_individuals: int) -> Postprocessor:
+ """Creates a postprocessor for top-down pose estimation
+
+ Args:
+ max_individuals: the maximum number of detections to keep in a single image
+
+ Returns:
+ A default top-down Postprocessor
+ """
+ return ComposePostprocessor(
+ components=[
+ ConcatenateOutputs(
+ keys_to_concatenate={
+ "bboxes": ("detection", "bboxes"),
+ "bbox_scores": ("detection", "scores"),
+ }
+ ),
+ TrimOutputs(
+ max_individuals={
+ "bboxes": max_individuals,
+ "bbox_scores": max_individuals,
+ },
+ ),
+ BboxToCoco(bounding_box_keys=["bboxes"]),
+ RescaleAndOffset(
+ keys_to_rescale=["bboxes"],
+ mode=RescaleAndOffset.Mode.BBOX_XYWH,
+ ),
+ ]
+ )
+
+
+class ComposePostprocessor(Postprocessor):
+ """
+ Class to preprocess an image and turn it into a batch of
+ inputs before running inference
+ """
+
+ def __init__(self, components: list[Postprocessor]) -> None:
+ self.components = components
+
+ def __call__(self, predictions: Any, context: Context) -> tuple[Any, Context]:
+ for postprocessor in self.components:
+ predictions, context = postprocessor(predictions, context)
+ return predictions, context
+
+
+class ConcatenateOutputs(Postprocessor):
+ """Checks that there is a single prediction for the image and returns it"""
+
+ def __init__(
+ self,
+ keys_to_concatenate: dict[str, tuple[str, str]],
+ empty_shapes: dict[str, tuple[int, ...]] | None = None,
+ create_empty_outputs: bool = False,
+ ):
+ self.keys_to_concatenate = keys_to_concatenate
+ self.empty_shapes = empty_shapes
+ self.create_empty_outputs = create_empty_outputs
+
+ if self.create_empty_outputs:
+ if not all([k in self.empty_shapes for k in self.keys_to_concatenate]):
+ raise ValueError(
+ "You must provide the expected shape for all keys to concatenate"
+ f" when create_empty_outputs is true, found {self.empty_shapes}"
+ )
+
+ def __call__(
+ self, predictions: Any, context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ if len(predictions) == 0:
+ outputs = {
+ name: np.zeros((0, *self.empty_shapes[name]))
+ for name in self.keys_to_concatenate.keys()
+ }
+ return outputs, context
+
+ outputs = {}
+ for output_name, head_key in self.keys_to_concatenate.items():
+ head_name, val_name = head_key
+ outputs[output_name] = np.concatenate(
+ [p[head_name][val_name] for p in predictions]
+ )
+
+ return outputs, context
+
+
+class PadOutputs(Postprocessor):
+ """Pads the outputs to have the maximum number of individuals"""
+
+ def __init__(
+ self,
+ max_individuals: dict[str, int],
+ pad_value: int,
+ ):
+ self.max_individuals = max_individuals
+ self.pad_value = pad_value
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ for name in predictions:
+ output = predictions[name]
+ if (
+ name in self.max_individuals
+ and len(output) < self.max_individuals[name]
+ ):
+ pad_size = self.max_individuals[name] - len(output)
+ tail_shape = output.shape[1:]
+ padding = self.pad_value * np.ones((pad_size, *tail_shape))
+ predictions[name] = np.concatenate([output, padding])
+
+ return predictions, context
+
+
+class TrimOutputs(Postprocessor):
+ """Ensures all outputs have at most `max_individuals` detections
+
+ Assumes that the outputs are sorted by decreasing score, such that the first
+ `max_individuals` predictions are the ones to keep.
+ """
+
+ def __init__(self, max_individuals: dict[str, int]):
+ self.max_individuals = max_individuals
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ for name in predictions:
+ output = predictions[name]
+ if len(output) > self.max_individuals[name]:
+ predictions[name] = output[:self.max_individuals[name]]
+
+ return predictions, context
+
+
+class RescaleAndOffset(Postprocessor):
+ """Rescales and offsets predictions back to their position in the original image
+
+ This can be done in 3 ways:
+ BBOX_XYWH: the data has shape (num_individuals, 4), in xywh format, and there
+ is a single scale and offset for all bounding boxes (e.g., because the image
+ was resized before being passed to a detector)
+ KEYPOINT: the data has shape (num_individuals, num_keypoints, 2/3), and there
+ is a single scale and offset for all individuals (e.g., because the image
+ was resized before being passed to a BU pose model)
+ KEYPOINT_TD: the data has shape (num_individuals, num_keypoints, 2/3), and there
+ are num_individuals scales and offsets (one for each individual, as TD crops
+ one image per individual)
+
+ If no scale and no offsets are given, then this postprocessor simply forwards the
+ predictions and context.
+ """
+
+ class Mode(Enum):
+ BBOX_XYWH = "bbox_xywh"
+ KEYPOINT = "keypoint"
+ KEYPOINT_TD = "keypoint_td"
+
+ def __init__(
+ self,
+ keys_to_rescale: list[str],
+ mode: RescaleAndOffset.Mode,
+ ) -> None:
+ super().__init__()
+ self.keys_to_rescale = keys_to_rescale
+ self.mode = mode
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ if "scales" not in context and "offsets" not in context:
+ # no rescaling needed
+ return predictions, context
+
+ updated_predictions = {}
+ scales, offsets = context["scales"], context["offsets"]
+ for name, outputs in predictions.items():
+ if name in self.keys_to_rescale:
+ if self.mode == self.Mode.BBOX_XYWH:
+ rescaled = outputs.copy()
+ rescaled[:, 0] = outputs[:, 0] * scales[0] + offsets[0]
+ rescaled[:, 1] = outputs[:, 1] * scales[1] + offsets[1]
+ rescaled[:, 2] = outputs[:, 2] * scales[0]
+ rescaled[:, 3] = outputs[:, 3] * scales[1]
+ elif self.mode == self.Mode.KEYPOINT:
+ rescaled = outputs.copy()
+ rescaled[..., :2] = outputs[..., :2] * scales + offsets
+ else: # Mode.KEYPOINT_TD
+ if not len(outputs) == len(scales) == len(offsets):
+ raise ValueError(
+ "There must be as many 'scales' and 'offsets' as outputs, found "
+ f"{len(outputs)}, {len(scales)}, {len(offsets)}"
+ )
+
+ if len(outputs) == 0:
+ rescaled = outputs
+ else:
+ rescaled_individuals = []
+ for output, scale, offset in zip(outputs, scales, offsets):
+ output_rescaled = output.copy()
+ output_rescaled[:, :2] = output[:, :2] * scale + offset
+ rescaled_individuals.append(output_rescaled)
+ rescaled = np.stack(rescaled_individuals)
+
+ updated_predictions[name] = rescaled
+ else:
+ updated_predictions[name] = outputs.copy()
+
+ return updated_predictions, context
+
+
+class BboxToCoco(Postprocessor):
+ """Transforms bounding boxes from xyxy to COCO format (xywh)"""
+
+ def __init__(self, bounding_box_keys: list[str]) -> None:
+ super().__init__()
+ self.bounding_box_keys = bounding_box_keys
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ for bbox_key in self.bounding_box_keys:
+ predictions[bbox_key][:, 2] -= predictions[bbox_key][:, 0]
+ predictions[bbox_key][:, 3] -= predictions[bbox_key][:, 1]
+
+ return predictions, context
+
+
+class AddContextToOutput(Postprocessor):
+ """
+ Adds items from the context to the output, such as the bounding boxes contained
+ during top-down inference.
+ """
+
+ def __init__(self, keys: list[str]) -> None:
+ super().__init__()
+ self.keys = keys
+
+ def __call__(
+ self,
+ predictions: dict[str, np.ndarray],
+ context: Context,
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ for k in self.keys:
+ if k in context:
+ predictions[k] = context[k].copy()
+ return predictions, context
+
+
+class PredictKeypointIdentities(Postprocessor):
+ """Assigns predicted identities to keypoints
+
+ The identity maps have shape (h, w, num_ids).
+
+ Attributes:
+ identity_key: Key with which to add predicted identities in the predictions dict
+ identity_map_key: Key for the identity maps in the predictions dict
+ pose_key: Key for the bodyparts in the predictions dict
+ keep_id_maps: Whether to keep identity heatmaps in the output dictionary.
+ Setting this value to True can be useful for debugging, but can lead to
+ memory issues when running video analysis on long videos.
+ """
+
+ def __init__(
+ self,
+ identity_key: str,
+ identity_map_key: str,
+ pose_key: str,
+ keep_id_maps: bool = False,
+ ) -> None:
+ self.identity_key = identity_key
+ self.identity_map_key = identity_map_key
+ self.pose_key = pose_key
+ self.keep_id_maps = keep_id_maps
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ pose = predictions[self.pose_key]
+ num_preds, num_keypoints, _ = pose.shape
+
+ identity_heatmap = predictions[self.identity_map_key] # (h, w, num_ids)
+ h, w, num_ids = identity_heatmap.shape
+
+ id_score_matrix = np.zeros((num_preds, num_keypoints, num_ids))
+ for pred_idx, individual_keypoints in enumerate(pose):
+ heatmap_indices = np.rint(individual_keypoints).astype(int)
+ xs = np.clip(heatmap_indices[:, 0], 0, w - 1)
+ ys = np.clip(heatmap_indices[:, 1], 0, h - 1)
+
+ # get the score from each identity heatmap at each predicted keypoint
+ for kpt_idx, (x, y) in enumerate(zip(xs, ys)):
+ id_score_matrix[pred_idx, kpt_idx] = identity_heatmap[y, x, :]
+
+ predictions[self.identity_key] = id_score_matrix
+ if not self.keep_id_maps:
+ # delete the heatmaps as this saves memory
+ id_heatmaps = predictions.pop(self.identity_map_key)
+ del id_heatmaps
+
+ return predictions, context
+
+
+class AssignIndividualIdentities(Postprocessor):
+ """Assigns predicted identities to individuals
+
+ Attributes:
+ identity_key: Key with which to add predicted identities in the predictions dict
+ pose_key: Key for the bodyparts in the predictions dict
+ """
+
+ def __init__(self, identity_key: str, pose_key: str) -> None:
+ self.identity_key = identity_key
+ self.pose_key = pose_key
+
+ def __call__(
+ self, predictions: dict[str, np.ndarray], context: Context
+ ) -> tuple[dict[str, np.ndarray], Context]:
+ map_ = assign_identity(predictions["bodyparts"], predictions["identity_scores"])
+ predictions["bodyparts"] = predictions["bodyparts"][map_]
+ predictions["identity_scores"] = predictions["identity_scores"][map_]
+ return predictions, context
+
+
+class PrepareBackboneFeatures(Postprocessor):
+ """Adds backbone features for each individual and keypoint to the outputs
+
+ Attributes:
+ top_down: Whether the model is a top-down model.
+ """
+
+ def __init__(self, top_down: bool) -> None:
+ self.top_down = top_down
+
+ def __call__(self, predictions: Any, context: Context) -> tuple[Any, Context]:
+ if self.top_down:
+ input_w, input_h = context["top_down_crop_size"]
+ else:
+ input_w, input_h = context["image_size"]
+
+ for pred in predictions:
+ features: np.ndarray = pred["backbone"]["features"]
+ pose: np.ndarray = pred["bodypart"]["poses"]
+
+ # only extract features from valid pose
+ mask = ~np.all((pose < 0) | np.isnan(pose), axis=(1, 2))
+ pose = pose[mask]
+ pred["bodypart"]["poses"] = pose.copy()
+
+ pose = np.nan_to_num(pose, nan=0)
+
+ num_features, h, w = features.shape
+ backbone_stride = input_w / w, input_h / h
+
+ num_preds, num_keypoints, _ = pose.shape
+
+ bodypart_features = np.zeros((num_preds, num_keypoints, num_features))
+ indices = np.rint(pose[..., :2] / backbone_stride).astype(int)
+ indices[..., 0] = np.clip(indices[..., 0], 0, w - 1)
+ indices[..., 1] = np.clip(indices[..., 1], 0, h - 1)
+
+ for idv, idv_indices in enumerate(indices):
+ for kpt, (x, y) in enumerate(idv_indices):
+ # only assign features if the pose was defined
+ if np.sum(x + y) > 0:
+ bodypart_features[idv, kpt] = features[:, y, x]
+
+ pred["backbone"]["bodypart_features"] = bodypart_features
+
+ return predictions, context
diff --git a/deeplabcut/pose_estimation_pytorch/data/preprocessor.py b/deeplabcut/pose_estimation_pytorch/data/preprocessor.py
new file mode 100644
index 0000000000..61df02b019
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/preprocessor.py
@@ -0,0 +1,353 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Helpers to run preprocess data before running inference"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, TypeVar
+
+import albumentations as A
+import cv2
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.data.image import top_down_crop
+
+
+Image = TypeVar("Image", torch.Tensor, np.ndarray, str, Path)
+Context = TypeVar("Context", dict[str, Any], None)
+
+
+class Preprocessor(ABC):
+ """
+ Class to preprocess an image and turn it into a batch of inputs before running
+ inference.
+
+ As an example, a pre-processor can load an image, use a "bboxes" key from context
+ to crop bounding boxes for individuals (going from a (h, w, 3) array to a
+ (num_individuals, h, w, 3) array), and convert it into a tensor ready for inference.
+ """
+
+ @abstractmethod
+ def __call__(self, image: Image, context: Context) -> tuple[Image, Context]:
+ """Pre-processes an image
+
+ Args:
+ image: an image (containing height, width and channel dimensions) or a
+ batch of images linked to a single input (containing an extra batch
+ dimension)
+ context: the context for this image or batch of images (such as bounding
+ boxes, conditional pose, ...)
+
+ Returns:
+ the pre-processed image (or batch of images) and their context
+ """
+ pass
+
+
+def build_bottom_up_preprocessor(
+ color_mode: str, transform: A.BaseCompose
+) -> Preprocessor:
+ """Creates a preprocessor for bottom-up pose estimation (or object detection)
+
+ Creates a preprocessor that loads an image, runs some transform on it (such as
+ normalization), creates a tensor from the numpy array (going from (h, w, 3) to
+ (3, h, w)) and adds a batch dimension (so the final tensor shape is (1, 3, h, w))
+
+ Args:
+ color_mode: whether to load the image as an RGB or BGR
+ transform: the transform to apply to the image
+
+ Returns:
+ A default bottom-up Preprocessor
+ """
+ return ComposePreprocessor(
+ components=[
+ LoadImage(color_mode),
+ AugmentImage(transform),
+ ToTensor(),
+ ToBatch(),
+ ]
+ )
+
+
+def build_top_down_preprocessor(
+ color_mode: str,
+ transform: A.BaseCompose,
+ top_down_crop_size: tuple[int, int],
+ top_down_crop_margin: int = 0,
+) -> Preprocessor:
+ """Creates a preprocessor for top-down pose estimation
+
+ Creates a preprocessor that loads an image, crops all bounding boxes given as a
+ context (through a "bboxes" key), runs some transforms on each cropped image (such
+ as normalization), creates a tensor from the numpy array (going from
+ (num_ind, h, w, 3) to (num_ind, 3, h, w)).
+
+ Args:
+ color_mode: whether to load the image as an RGB or BGR
+ transform: the transform to apply to the image
+ top_down_crop_size: the (width, height) to resize cropped bboxes to
+ top_down_crop_margin: the margin to add around detected bboxes for the crop
+
+ Returns:
+ A default top-down Preprocessor
+ """
+ return ComposePreprocessor(
+ components=[
+ LoadImage(color_mode),
+ TopDownCrop(output_size=top_down_crop_size, margin=top_down_crop_margin),
+ AugmentImage(transform),
+ ToTensor(),
+ ]
+ )
+
+
+class ComposePreprocessor(Preprocessor):
+ """
+ Class to preprocess an image and turn it into a batch of
+ inputs before running inference
+ """
+
+ def __init__(self, components: list[Preprocessor]) -> None:
+ self.components = components
+
+ def __call__(self, image: Image, context: Context) -> tuple[Image, Context]:
+ for preprocessor in self.components:
+ image, context = preprocessor(image, context)
+ return image, context
+
+
+class LoadImage(Preprocessor):
+ """Loads an image from a file, if not yet loaded"""
+
+ def __init__(self, color_mode: str = "RBG") -> None:
+ self.color_mode = color_mode
+
+ def __call__(self, image: Image, context: Context) -> tuple[np.ndarray, Context]:
+ if isinstance(image, (str, Path)):
+ image_ = cv2.imread(str(image))
+ if self.color_mode == "RGB":
+ image_ = cv2.cvtColor(image_, cv2.COLOR_BGR2RGB)
+ else:
+ image_ = image
+
+ h, w = image_.shape[:2]
+ context["image_size"] = w, h
+ return image_, context
+
+
+class AugmentImage(Preprocessor):
+ """
+
+ Adds an offset and scale key to the context:
+ offset: (x, y) position of the pixel in the top left corner of the augmented
+ image in the original image
+ scale: size of the original image divided by the size of the new image
+
+ This allows to map the position of predictions in the transformed image back to the
+ original image space.
+ p_original = p_transformed * scale + offset
+ p_transformed = (p_original - offset) / scale
+ """
+
+ def __init__(self, transform: A.BaseCompose) -> None:
+ self.transform = transform
+
+ @staticmethod
+ def get_offsets_and_scales(
+ h: int,
+ w: int,
+ output_bboxes: list[tuple[float, float, float, float]],
+ ) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]:
+ offsets, scales = [], []
+ for bbox in output_bboxes:
+ x_origin, y_origin, w_out, h_out = bbox
+ x_scale, y_scale = w / w_out, h / h_out
+ x_offset = -x_origin * x_scale
+ y_offset = -y_origin * y_scale
+ offsets.append((x_offset, y_offset))
+ scales.append((x_scale, y_scale))
+
+ return offsets, scales
+
+ @staticmethod
+ def update_offset(
+ offset: tuple[float, float],
+ scale: tuple[float, float],
+ new_offset: tuple[float, float],
+ ) -> tuple[float, float]:
+ return (
+ scale[0] * new_offset[0] + offset[0],
+ scale[1] * new_offset[1] + offset[1],
+ )
+
+ @staticmethod
+ def update_scale(
+ scale: tuple[float, float], new_scale: tuple[float, float]
+ ) -> tuple[float, float]:
+ return scale[0] * new_scale[0], scale[1] * new_scale[1]
+
+ @staticmethod
+ def update_offsets_and_scales(context, new_offsets, new_scales) -> tuple:
+ """
+ x = x' * scale' + offset'
+ x' = x'' * scale'' + offset''
+ -> x = x'' * (scale' * scale'') + (scale' * offset'' + offset')
+ """
+ # scales and offsets are either both lists or both tuples
+ offsets = context.get("offsets", (0, 0))
+ scales = context.get("scales", (1, 1))
+ if isinstance(offsets, tuple):
+ if isinstance(new_offsets, list):
+ updated_offsets = [
+ AugmentImage.update_offset(offsets, scales, new_offset)
+ for new_offset in new_offsets
+ ]
+ updated_scales = [
+ AugmentImage.update_scale(scales, new_scale)
+ for new_scale in new_scales
+ ]
+ else:
+ if not len(offsets) == len(new_offsets):
+ raise ValueError("Cannot rescale lists when not same length")
+
+ updated_offsets = AugmentImage.update_offset(
+ offsets, scales, new_offsets
+ )
+ updated_scales = AugmentImage.update_scale(scales, new_scales)
+ else:
+ if isinstance(new_offsets, list):
+ if not len(offsets) == len(new_offsets):
+ raise ValueError("Cannot rescale lists when not same length")
+
+ updated_offsets = [
+ AugmentImage.update_offset(offset, scale, new_offset)
+ for offset, scale, new_offset in zip(offsets, scales, new_offsets)
+ ]
+ updated_scales = [
+ AugmentImage.update_scale(scale, new_scale)
+ for scale, new_scale in zip(scales, new_scales)
+ ]
+ else:
+ updated_offsets = [
+ AugmentImage.update_offset(offset, scale, new_offsets)
+ for offset, scale in zip(offsets, scales)
+ ]
+ updated_scales = [
+ AugmentImage.update_scale(scale, new_scales) for scale in scales
+ ]
+ return updated_offsets, updated_scales
+
+ def __call__(self, image: Image, context: Context) -> tuple[np.ndarray, Context]:
+ # If the image is a batch, process each entry
+ if len(image.shape) == 4:
+ batch_size, h, w, _ = image.shape
+ if batch_size == 0:
+ # no images in top-down when no detections
+ offsets, scales = (0, 0), (1, 1)
+ else:
+ transformed = [
+ self.transform(
+ image=img,
+ keypoints=[],
+ class_labels=[],
+ bboxes=[[0, 0, w, h]],
+ bbox_labels=["image"],
+ )
+ for img in image
+ ]
+ image = np.stack([t["image"] for t in transformed])
+ output_bboxes = [t["bboxes"][0] for t in transformed]
+ offsets, scales = self.get_offsets_and_scales(h, w, output_bboxes)
+ else:
+ h, w, _ = image.shape
+ transformed = self.transform(
+ image=image,
+ keypoints=[],
+ class_labels=[],
+ bboxes=[[0, 0, w, h]],
+ bbox_labels=["image"],
+ )
+ image = transformed["image"]
+ output_bboxes = [transformed["bboxes"][0]]
+ offsets, scales = self.get_offsets_and_scales(h, w, output_bboxes)
+ offsets = offsets[0]
+ scales = scales[0]
+
+ offsets, scales = self.update_offsets_and_scales(context, offsets, scales)
+ context["offsets"] = offsets
+ context["scales"] = scales
+ return image, context
+
+
+class ToTensor(Preprocessor):
+ """Transforms lists and numpy arrays into tensors"""
+
+ def __call__(self, image: Image, context: Context) -> tuple[np.ndarray, Context]:
+ image = torch.tensor(image, dtype=torch.float)
+ if len(image.shape) == 4:
+ image = image.permute(0, 3, 1, 2)
+ else:
+ image = image.permute(2, 0, 1)
+ return image, context
+
+
+class ToBatch(Preprocessor):
+ """TODO"""
+
+ def __call__(self, image: Image, context: Context) -> tuple[np.ndarray, Context]:
+ return image.unsqueeze(0), context
+
+
+class TopDownCrop(Preprocessor):
+ """Crops bounding boxes out of images for top-down pose estimation
+
+ Args:
+ output_size: The (width, height) of crops to output
+ margin: The margin to add around detected bounding boxes before cropping
+ """
+
+ def __init__(self, output_size: int | tuple[int, int], margin: int = 0) -> None:
+ if isinstance(output_size, int):
+ output_size = (output_size, output_size)
+
+ self.output_size = output_size
+ self.margin = margin
+
+ def __call__(
+ self, image: np.ndarray, context: Context
+ ) -> tuple[np.ndarray, Context]:
+ """TODO: numpy implementation"""
+ if "bboxes" not in context:
+ raise ValueError(f"Must include bboxes to CropDetections, found {context}")
+
+ images, offsets, scales = [], [], []
+ for bbox in context["bboxes"]:
+ crop, offset, scale = top_down_crop(
+ image, bbox, self.output_size, margin=self.margin
+ )
+ images.append(crop)
+ offsets.append(offset)
+ scales.append(scale)
+
+ context["offsets"] = np.array(offsets)
+ context["scales"] = np.array(scales)
+
+ # can have no bounding boxes if detector made no detections
+ if len(images) == 0:
+ images = np.zeros((0, *image.shape))
+ else:
+ images = np.stack(images, axis=0)
+
+ context["top_down_crop_size"] = self.output_size
+ return images, context
diff --git a/deeplabcut/pose_estimation_pytorch/data/transforms.py b/deeplabcut/pose_estimation_pytorch/data/transforms.py
new file mode 100644
index 0000000000..b11c736498
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/transforms.py
@@ -0,0 +1,671 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import warnings
+from typing import Any, Iterable, Sequence
+
+import albumentations as A
+import cv2
+import numpy as np
+from albumentations.augmentations.geometric import functional as F
+from numpy.typing import NDArray
+from scipy.spatial.distance import pdist, squareform
+from scipy.stats import truncnorm
+
+
+def build_transforms(augmentations: dict) -> A.BaseCompose:
+ transforms = []
+
+ if resize_aug := augmentations.get("resize", False):
+ transforms += build_resize_transforms(resize_aug)
+
+ if (lms_cfg := augmentations.get("longest_max_size")) is not None:
+ transforms.append(A.LongestMaxSize(lms_cfg))
+
+ if hflip_cfg := augmentations.get("hflip"):
+ hflip_proba = 0.5
+ symmetries = None
+ if isinstance(hflip_cfg, float):
+ hflip_proba = hflip_cfg
+ elif isinstance(hflip_cfg, dict):
+ if "p" in hflip_cfg:
+ hflip_proba = float(hflip_cfg["p"])
+
+ if "symmetries" in hflip_cfg:
+ symmetries = []
+ for kpt_a, kpt_b in hflip_cfg["symmetries"]:
+ symmetries.append((int(kpt_a), int(kpt_b)))
+
+ if symmetries is not None:
+ transforms.append(HFlip(symmetries=symmetries, p=hflip_proba))
+ else:
+ warnings.warn(
+ "Be careful! Do not train pose models with horizontal flips if you have"
+ " symmetric keypoints!"
+ )
+ transforms.append(A.HorizontalFlip(p=hflip_proba))
+
+ if (affine := augmentations.get("affine")) is not None:
+ scaling = affine.get("scaling")
+ rotation = affine.get("rotation")
+ translation = affine.get("translation")
+ if rotation is not None:
+ rotation = (-rotation, rotation)
+ if translation is not None:
+ translation = (-translation, translation)
+
+ transforms.append(
+ A.Affine(
+ scale=scaling,
+ rotate=rotation,
+ translate_px=translation,
+ p=affine.get("p", 0.9),
+ keep_ratio=True,
+ )
+ )
+
+ if bbox_tfm := augmentations.get("random_bbox_transform", False):
+ transforms.append(
+ RandomBBoxTransform(
+ shift_factor=bbox_tfm.get("shift_factor", 0.1),
+ shift_prob=bbox_tfm.get("shift_prob", 0.25),
+ scale_factor=bbox_tfm.get("scale_factor", (0.75, 1.25)),
+ scale_prob=bbox_tfm.get("scale_prob", 1.0),
+ p=bbox_tfm.get("p", 1.0),
+ )
+ )
+
+ if crop_sampling := augmentations.get("crop_sampling"):
+ transforms.append(
+ A.PadIfNeeded(
+ min_height=crop_sampling["height"],
+ min_width=crop_sampling["width"],
+ border_mode=cv2.BORDER_CONSTANT,
+ always_apply=True,
+ )
+ )
+ transforms.append(
+ KeypointAwareCrop(
+ crop_sampling["width"],
+ crop_sampling["height"],
+ crop_sampling["max_shift"],
+ crop_sampling["method"],
+ )
+ )
+
+ if augmentations.get("hist_eq", False):
+ transforms.append(A.Equalize(p=0.5))
+ if augmentations.get("motion_blur", False):
+ transforms.append(A.MotionBlur(p=0.5))
+ if augmentations.get("covering", False):
+ transforms.append(
+ CoarseDropout(
+ max_holes=10,
+ max_height=0.05,
+ min_height=0.01,
+ max_width=0.05,
+ min_width=0.01,
+ p=0.5,
+ )
+ )
+ if augmentations.get("elastic_transform", False):
+ transforms.append(ElasticTransform(sigma=5, p=0.5))
+ if augmentations.get("grayscale", False):
+ transforms.append(Grayscale(alpha=(0.5, 1.0)))
+ if noise := augmentations.get("gaussian_noise", False):
+ # TODO inherit custom gaussian transform to support per_channel = 0.5
+ if not isinstance(noise, (int, float)):
+ noise = 0.05 * 255
+ transforms.append(
+ A.GaussNoise(
+ var_limit=(0, noise ** 2),
+ mean=0,
+ per_channel=True,
+ # Albumentations doesn't support per_channel = 0.5
+ p=0.5,
+ )
+ )
+
+ if augmentations.get("auto_padding"):
+ transforms.append(build_auto_padding(**augmentations["auto_padding"]))
+
+ if augmentations.get("normalize_images"):
+ transforms.append(
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ )
+
+ return A.Compose(
+ transforms,
+ keypoint_params=A.KeypointParams(
+ "xy", remove_invisible=False, label_fields=["class_labels"]
+ ),
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+
+
+def build_auto_padding(
+ min_height: int | None = None,
+ min_width: int | None = None,
+ pad_height_divisor: int | None = 1,
+ pad_width_divisor: int | None = 1,
+ position: str = "random", # TODO: Which default to set?
+ border_mode: str = "reflect_101", # TODO: Which default to set?
+ border_value: float | None = None,
+ border_mask_value: float | None = None,
+) -> A.PadIfNeeded:
+ """
+ Create an albumentations PadIfNeeded transform from a config
+
+ Args:
+ min_height: the minimum height of the image
+ min_width: the minimum width of the image
+ pad_height_divisor: if not None, ensures height is dividable by value of this argument
+ pad_width_divisor: if not None, ensures width is dividable by value of this argument
+ position: position of the image, one of the possible PadIfNeeded
+ border_mode: 'constant' or 'reflect_101' (see cv2.BORDER modes)
+ border_value: padding value if border_mode is 'constant'
+ border_mask_value: padding value for mask if border_mode is 'constant'
+
+ Raises:
+ ValueError:
+ Only one of 'min_height' and 'pad_height_divisor' parameters must be set
+ Only one of 'min_width' and 'pad_width_divisor' parameters must be set
+
+ Returns:
+ the auto-padding transform
+ """
+ border_modes = {
+ "constant": cv2.BORDER_CONSTANT,
+ "reflect_101": cv2.BORDER_REFLECT_101,
+ }
+ if border_mode not in border_modes:
+ raise ValueError(
+ f"Unknown border mode for auto_padding: {border_mode} "
+ f"(valid values are: {border_modes.keys()})"
+ )
+
+ return A.PadIfNeeded(
+ min_height=min_height,
+ min_width=min_width,
+ pad_height_divisor=pad_height_divisor,
+ pad_width_divisor=pad_width_divisor,
+ position=position,
+ border_mode=border_modes[border_mode],
+ value=border_value,
+ mask_value=border_mask_value,
+ )
+
+
+def build_resize_transforms(resize_cfg: dict) -> list[A.BasicTransform]:
+ height, width = resize_cfg["height"], resize_cfg["width"]
+
+ transforms = []
+ if resize_cfg.get("keep_ratio", True):
+ transforms.append(KeepAspectRatioResize(width=width, height=height, mode="pad"))
+ transforms.append(
+ A.PadIfNeeded(
+ min_height=height,
+ min_width=width,
+ border_mode=cv2.BORDER_CONSTANT,
+ position=A.PadIfNeeded.PositionType.TOP_LEFT,
+ )
+ )
+ else:
+ transforms.append(A.Resize(height, width))
+ return transforms
+
+
+class HFlip(A.HorizontalFlip):
+ """Horizontal Flip which swaps symmetric keypoints"""
+
+ def __init__(self, symmetries: list[tuple[int, int]], *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._symmetries = {}
+ for i, j in symmetries:
+ self._symmetries[i] = j
+ self._symmetries[j] = i
+
+ def apply_to_keypoints(self, keypoints, **params):
+ swapped_keypoints = [
+ keypoints[self._symmetries.get(kpt_idx, kpt_idx)]
+ for kpt_idx in range(len(keypoints))
+ ]
+ return super().apply_to_keypoints(swapped_keypoints, **params)
+
+
+class KeypointAwareCrop(A.RandomCrop):
+ """Random crop for an image around keypoints
+
+ Args:
+ width: Crop images down to this maximum width.
+ height: Crop images down to this maximum height.
+ max_shift: Maximum allowed shift of the cropping center position
+ as a fraction of the crop size.
+ crop_sampling: Crop centers sampling method. Must be either:
+ "uniform" (randomly over the image),
+ "keypoints" (randomly over the annotated keypoints),
+ "density" (weighing preferentially dense regions of keypoints),
+ "hybrid" (alternating randomly between "uniform" and "density").
+ """
+
+ def __init__(
+ self,
+ width: int,
+ height: int,
+ max_shift: float = 0.4,
+ crop_sampling: str = "hybrid",
+ ):
+ super().__init__(height, width, always_apply=True)
+ # Clamp to 40% of crop size to ensure that at least
+ # the center keypoint remains visible after the offset is applied.
+ self.max_shift = max(0.0, min(max_shift, 0.4))
+ if crop_sampling not in ("uniform", "keypoints", "density", "hybrid"):
+ raise ValueError(
+ f"Invalid sampling {crop_sampling}. Must be "
+ f"either 'uniform', 'keypoints', 'density', or 'hybrid."
+ )
+ self.crop_sampling = crop_sampling
+
+ @staticmethod
+ def calc_n_neighbors(xy: NDArray, radius: float) -> NDArray:
+ d = pdist(xy, "sqeuclidean")
+ mat = squareform(d <= radius * radius, checks=False)
+ return np.sum(mat, axis=0)
+
+ @property
+ def targets_as_params(self) -> list[str]:
+ return ["image", "keypoints"]
+
+ def get_params_dependent_on_targets(self, params: dict[str, Any]) -> dict[str, Any]:
+ img = params["image"]
+ kpts = params["keypoints"]
+ shift_factors = np.random.random(2)
+ shift = self.max_shift * shift_factors * np.array([self.width, self.height])
+ sampling = self.crop_sampling
+ if self.crop_sampling == "hybrid":
+ sampling = np.random.choice(["uniform", "density"])
+ if len(kpts) == 0:
+ sampling = "uniform"
+ if sampling == "uniform":
+ center = np.random.random(2)
+ else:
+ h, w = img.shape[:2]
+ kpts = np.array([[k[0], k[1]] for k in kpts])
+ kpts = kpts[~np.isnan(kpts).all(axis=1)]
+ n_kpts = kpts.shape[0]
+ inds = np.arange(n_kpts)
+ if sampling == "density":
+ # Points located close to one another are sampled preferentially
+ # in order to augment crowded regions.
+ radius = 0.1 * min(h, w)
+ n_neighbors = self.calc_n_neighbors(kpts, radius)
+ # Include keypoints in the count to avoid null probabilities
+ n_neighbors += 1
+ p = n_neighbors / n_neighbors.sum()
+ else:
+ p = np.ones_like(inds) / n_kpts
+ center = kpts[np.random.choice(inds, p=p)]
+ # Shift the crop center in both dimensions by random amounts
+ # and normalize to the original image dimensions.
+ center = (center + shift) / [w, h]
+ center = np.clip(center, 0, np.nextafter(1, 0)) # Clip to 1 exclusive
+ return {"h_start": center[1], "w_start": center[0]}
+
+ def apply_to_keypoints(
+ self,
+ keypoints,
+ **params,
+ ) -> list[tuple[float]]:
+ keypoints = super().apply_to_keypoints(keypoints, **params)
+ new_keypoints = []
+ for kp in keypoints:
+ x, y = kp[:2]
+ if not (0 <= x < self.width and 0 <= y < self.height):
+ kp = list(kp)
+ kp[:2] = np.nan, np.nan
+ kp = tuple(kp)
+ new_keypoints.append(kp)
+ return new_keypoints
+
+ def get_transform_init_args_names(self) -> tuple[str, ...]:
+ return "width", "height", "max_shift", "crop_sampling"
+
+
+class KeepAspectRatioResize(A.DualTransform):
+ """Resizes images while preserving their aspect ratio
+
+ In 'pad' mode, the image will be rescaled to the largest possible size such that
+ it can be padded to the correct size (with PadIfNeeded). So we'll have:
+ output_width <= width, output_height <= height
+
+ In 'crop' mode, the image will be rescaled to the smallest possible size such that
+ it can be cropped to the correct size (with any random crop you want), so:
+ output_width >= width, output_height >= height
+ """
+
+ def __init__(
+ self,
+ width: int,
+ height: int,
+ mode: str = "pad",
+ interpolation: Any = cv2.INTER_LINEAR,
+ p: float = 1.0,
+ always_apply: bool = True,
+ ) -> None:
+ super().__init__(always_apply=always_apply, p=p)
+ self.height = height
+ self.width = width
+ self.mode = mode
+ self.interpolation = interpolation
+
+ def apply(self, img, scale=0, interpolation=cv2.INTER_LINEAR, **params):
+ return A.scale(img, scale, interpolation)
+
+ def apply_to_bbox(self, bbox, **params):
+ # Bounding box coordinates are scale invariant
+ return bbox
+
+ def apply_to_keypoint(self, keypoint, scale=0, **params):
+ keypoint = A.keypoint_scale(keypoint, scale, scale)
+ return keypoint
+
+ @property
+ def targets_as_params(self) -> list[str]:
+ return ["image"]
+
+ def get_params_dependent_on_targets(self, params: dict[str, Any]) -> dict[str, Any]:
+ h, w, _ = params["image"].shape
+ if self.mode == "pad":
+ scale = min(self.height / h, self.width / w)
+ else:
+ scale = max(self.height / h, self.width / w)
+
+ return {"scale": scale}
+
+ def get_transform_init_args_names(self):
+ return "height", "width", "mode", "interpolation"
+
+
+class Grayscale(A.ToGray):
+ def __init__(
+ self,
+ alpha: float | int | tuple[float, float] = 1.0,
+ always_apply: bool = False,
+ p: float = 0.5,
+ ):
+ """
+ Args:
+ alpha: int, float or tuple of floats, optional
+ The alpha value of the new colorspace when overlaid over the
+ old one. A value close to 1.0 means that mostly the new
+ colorspace is visible. A value close to 0.0 means that mostly the
+ old image is visible.
+
+ * If a float, exactly that value will be used.
+ * If a tuple ``(a, b)``, a random value from the range
+ ``a <= x <= b`` will be sampled per image.
+ """
+ super().__init__(always_apply, p)
+ if isinstance(alpha, (float, int)):
+ self._alpha = self._validate_alpha(alpha)
+ elif isinstance(alpha, tuple):
+ if len(alpha) != 2:
+ raise ValueError("`alpha` must be a tuple of two numbers.")
+ self._alpha = tuple([self._validate_alpha(val) for val in alpha])
+ else:
+ raise ValueError("")
+
+ @staticmethod
+ def _validate_alpha(val: float) -> float:
+ if not 0.0 <= val <= 1.0:
+ warnings.warn("`alpha` will be clipped to the interval [0.0, 1.0].")
+ return min(1.0, max(0.0, val))
+
+ @property
+ def alpha(self) -> float:
+ if isinstance(self._alpha, float):
+ return self._alpha
+ return np.random.uniform(*self._alpha)
+
+ def apply(self, img: NDArray, **params) -> NDArray:
+ img_gray = super().apply(img, **params)
+ alpha = self.alpha
+ img_blend = img * (1 - alpha) + img_gray * alpha
+ return img_blend.astype(img.dtype)
+
+
+class ElasticTransform(A.ElasticTransform):
+ def __init__(
+ self,
+ alpha: float = 20.0,
+ sigma: float = 5.0, # As in DLC TF
+ alpha_affine: float = 0.0, # Deactivate affine prior to elastic deformation
+ interpolation: int = cv2.INTER_CUBIC, # As in imgaug
+ border_mode: int = cv2.BORDER_CONSTANT, # As in imgaug
+ value: float | None = None,
+ mask_value: float | None = None,
+ always_apply: bool = False,
+ approximate: bool = True, # Faster by a factor of 2
+ same_dxdy: bool = True, # Here too
+ p: float = 0.5,
+ ):
+ super().__init__(
+ alpha,
+ sigma,
+ alpha_affine,
+ interpolation,
+ border_mode,
+ value,
+ mask_value,
+ always_apply,
+ approximate,
+ same_dxdy,
+ p,
+ )
+ self._neighbor_dist = 3
+ self._neighbor_dist_square = self._neighbor_dist ** 2
+
+ def apply_to_keypoints(
+ self, keypoints: Sequence[float], random_state: int | None = None, **params
+ ) -> list[float]:
+ heatmaps = np.zeros(
+ (params["rows"], params["cols"], len(keypoints)), dtype=np.float32
+ )
+ grid = np.mgrid[: params["rows"], : params["cols"]].transpose((1, 2, 0))
+ kpts = np.array([(k[1], k[0]) for k in keypoints])
+ valid_kpts = np.all(kpts > 0.0, axis=1)
+ dist = ((grid - kpts[:, None, None]) ** 2).sum(axis=3)
+ mask = (dist <= self._neighbor_dist_square) & valid_kpts[:, None, None]
+ heatmaps[mask.transpose(1, 2, 0)] = 1
+
+ heatmaps_aug = F.elastic_transform(
+ heatmaps,
+ self.alpha,
+ self.sigma,
+ self.alpha_affine,
+ cv2.INTER_NEAREST,
+ self.border_mode,
+ self.mask_value,
+ np.random.RandomState(random_state),
+ self.approximate,
+ self.same_dxdy,
+ )
+
+ inds = np.indices(heatmaps_aug.shape[:2])[::-1]
+ mask = np.transpose(heatmaps_aug == 1, (2, 0, 1))
+ # Let's compute the average, rather than the median, coordinates
+ div = np.sum(mask, axis=(1, 2))
+ sum_indices = np.sum(inds[:, None] * mask[None], axis=(2, 3)).T
+ xy = sum_indices / div[:, None]
+ new_keypoints = []
+ for kp, new_coords in zip(keypoints, xy):
+ kp = list(kp)
+ kp[:2] = new_coords
+ new_keypoints.append(tuple(kp))
+ return new_keypoints
+
+
+class CoarseDropout(A.CoarseDropout):
+ def __init__(
+ self,
+ max_holes: int = 8,
+ max_height: int | float = 8,
+ max_width: int | float = 8,
+ min_holes: int | None = None,
+ min_height: int | float | None = None,
+ min_width: int | float | None = None,
+ fill_value: int = 0,
+ mask_fill_value: int | None = None,
+ always_apply: bool = False,
+ p: float = 0.5,
+ ):
+ super().__init__(
+ max_holes,
+ max_height,
+ max_width,
+ min_holes,
+ min_height,
+ min_width,
+ fill_value,
+ mask_fill_value,
+ always_apply,
+ p,
+ )
+
+ def apply_to_bboxes(self, bboxes: Sequence[float], **params) -> list[float]:
+ return list(bboxes)
+
+ def apply_to_keypoints(
+ self,
+ keypoints: Sequence[float],
+ holes: Iterable[tuple[int, int, int, int]] = (),
+ **params,
+ ) -> list[float]:
+ new_keypoints = []
+ for kp in keypoints:
+ in_hole = False
+ for hole in holes:
+ if self._keypoint_in_hole(kp, hole):
+ in_hole = True
+ break
+ if in_hole:
+ kp = list(kp)
+ kp[:2] = np.nan, np.nan
+ kp = tuple(kp)
+ new_keypoints.append(kp)
+ return new_keypoints
+
+ def _keypoint_in_hole(self, keypoint, hole: tuple[int, int, int, int]) -> bool:
+ """Reimplemented from Albumentations as was removed in v1.4.0"""
+ x1, y1, x2, y2 = hole
+ x, y = keypoint[:2]
+ return x1 <= x < x2 and y1 <= y < y2
+
+
+class RandomBBoxTransform(A.DualTransform):
+ """Random jittering for bounding boxes for top-down pose estimation models.
+
+ Implementation based on the mmpose `RandomBBoxTransform`. For more information,
+ see .
+ """
+
+ def __init__(
+ self,
+ shift_factor: float = 0.1,
+ shift_prob: float = 0.25,
+ scale_factor: tuple[float, float] = (0.5, 1.5),
+ scale_prob: float = 1.0,
+ sampling: str = "truncnorm",
+ p: float = 1.0,
+ ):
+ super().__init__(p=p)
+ self.shift_factor = shift_factor
+ self.shift_prob = shift_prob
+ self.scale_factor = scale_factor
+ self.scale_prob = scale_prob
+ self.sampling = sampling
+
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
+ return img
+
+ def apply_to_keypoints(self, keypoints: np.ndarray, **params) -> np.ndarray:
+ return keypoints
+
+ def apply_to_bboxes(self, bboxes, **params):
+ if len(bboxes) == 0:
+ return bboxes
+
+ # Albumentations provides bounding boxes in normalized xyxy format internally
+ bboxes_xyxy = np.asarray(bboxes)
+ bboxes_extra = None
+ if bboxes_xyxy.shape[1] > 4:
+ # can't take from array - may have different dtype
+ bboxes_extra = [bbox[4:] for bbox in bboxes]
+ bboxes_xyxy = bboxes_xyxy[:, :4]
+
+ # sample parameters
+ bboxes_to_scale = np.random.random(len(bboxes)) < self.scale_prob
+ num_bboxes_to_scale = np.sum(bboxes_to_scale).item()
+ scale_factors = np.ones((len(bboxes), 2))
+ if num_bboxes_to_scale > 0:
+ scale_factors[bboxes_to_scale] = self._sample(
+ (num_bboxes_to_scale, 2),
+ low=self.scale_factor[0],
+ high=self.scale_factor[1],
+ )
+
+ bboxes_to_shift = np.random.random(len(bboxes)) < self.shift_prob
+ num_bboxes_to_shift = np.sum(bboxes_to_shift).item()
+ shift_factors = np.zeros((len(bboxes), 2))
+ if num_bboxes_to_shift > 0:
+ shift_factors[bboxes_to_shift] = self._sample(
+ (num_bboxes_to_shift, 2),
+ low=-self.shift_factor,
+ high=self.shift_factor,
+ )
+
+ bbox_wh = bboxes_xyxy[:, 2:] - bboxes_xyxy[:, :2]
+ bbox_cxcy = bboxes_xyxy[:, :2] + (0.5 * bbox_wh)
+
+ # scale + shift bounding boxes
+ bbox_cxcy = bbox_cxcy + (shift_factors * bbox_wh)
+ bbox_wh = bbox_wh * scale_factors
+
+ # convert to xyxy, clip so all bounding boxes are in the image
+ bbox_half_wh = 0.5 * bbox_wh
+ bbox_xyxy = np.empty((len(bboxes), 4))
+ bbox_xyxy[:, :2] = bbox_cxcy - bbox_half_wh
+ bbox_xyxy[:, 2:] = bbox_cxcy + bbox_half_wh
+ bbox_xyxy = np.clip(bbox_xyxy, 0, 1)
+
+ # add the extra information back; tuples for albumentations<=1.4.3
+ bboxes_out = [tuple(bbox) for bbox in bbox_xyxy]
+ if bboxes_extra is not None:
+ bboxes_out = [bbox + extra for bbox, extra in zip(bboxes_out, bboxes_extra)]
+ return bboxes_out
+
+ def get_transform_init_args_names(self):
+ return "shift_factor", "shift_prob", "scale_factor", "scale_prob", "sampling"
+
+ def _sample(
+ self,
+ size: tuple[int, ...],
+ low: float = -1.0,
+ high: float = 1.0,
+ ) -> np.ndarray:
+ if self.sampling == "truncnorm":
+ return truncnorm.rvs(low, high, size=size).astype(np.float32)
+ elif self.sampling == "uniform":
+ delta = high - low
+ return low + (delta * np.random.random(size))
+
+ raise ValueError(f"Unknown sampling: {self.sampling}")
diff --git a/deeplabcut/pose_estimation_pytorch/data/utils.py b/deeplabcut/pose_estimation_pytorch/data/utils.py
new file mode 100644
index 0000000000..68101913ca
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/data/utils.py
@@ -0,0 +1,549 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import reduce, lru_cache
+from pathlib import Path
+
+import albumentations as A
+import numpy as np
+import torch
+from PIL import Image
+from torchvision.ops import box_convert
+from torchvision.transforms import functional as F
+
+
+@lru_cache(maxsize=None)
+def read_image_shape_fast(path: str | Path) -> tuple[int, int, int]:
+ """Blazing fast and does not load the image into memory"""
+ with Image.open(path) as img:
+ width, height = img.size
+ return len(img.getbands()), height, width
+
+
+def bbox_from_keypoints(
+ keypoints: np.ndarray,
+ image_h: int,
+ image_w: int,
+ margin: int,
+) -> np.ndarray:
+ """
+ Computes bounding boxes from keypoints.
+
+ Args:
+ keypoints: (..., num_keypoints, xy) the keypoints from which to get bboxes
+ image_h: the height of the image
+ image_w: the width of the image
+ margin: the bounding box margin
+
+ Returns:
+ the bounding boxes for the keypoints, of shape (..., 4) in the xywh format
+ """
+ squeeze = False
+
+ # we do not estimate bbox on keypoints that have 0 or -1 flag
+ keypoints = np.copy(keypoints)
+ keypoints[keypoints[..., -1] <= 0] = np.nan
+
+ if len(keypoints.shape) == 2:
+ squeeze = True
+ keypoints = np.expand_dims(keypoints, axis=0)
+
+ bboxes = np.full((keypoints.shape[0], 4), np.nan)
+ bboxes[:, :2] = np.nanmin(keypoints[..., :2], axis=1) - margin # X1, Y1
+ bboxes[:, 2:4] = np.nanmax(keypoints[..., :2], axis=1) + margin # X2, Y2
+
+ # can have NaNs if some individuals have no visible keypoints
+ bboxes = np.nan_to_num(bboxes, nan=0)
+
+ bboxes = np.clip(
+ bboxes,
+ a_min=[0, 0, 0, 0],
+ a_max=[image_w, image_h, image_w, image_h],
+ )
+ bboxes[..., 2] = bboxes[..., 2] - bboxes[..., 0] # to width
+ bboxes[..., 3] = bboxes[..., 3] - bboxes[..., 1] # to height
+ if squeeze:
+ return bboxes[0]
+
+ return bboxes
+
+
+def merge_list_of_dicts(
+ list_of_dicts: list[dict], keys_to_include: list[str]
+) -> dict[str, list]:
+ """
+ Flattens a list of dictionaries into a dictionary with the lists concatenated.
+
+ Args:
+ list_of_dicts: the dictionaries to merge
+ keys_to_include: the keys to include in the new dictionary
+
+ Returns:
+ the merged dictionary
+
+ Examples:
+ input:
+ list_of_dicts: [{"id": 0, "num": 1}, {"id": 1, "num": 10}]
+ keys_to_include: ["id", "num"]
+ output:
+ {"id": [0, 1], "num": [1, 10]}
+ """
+ return reduce(
+ lambda acc, d: {
+ key: acc.get(key, []) + [value]
+ for key, value in d.items()
+ if key in keys_to_include
+ },
+ list_of_dicts,
+ defaultdict(list),
+ )
+
+
+def map_image_path_to_id(images: list[dict]) -> dict[str, int]:
+ """
+ Binds the image paths to their respective IDs.
+
+ Args:
+ images: List of dictionaries containing image data in COCO-like format.
+ Each dictionary should have 'file_name' and 'id' keys.
+
+ Returns:
+ A dictionary mapping image paths to their respective IDs.
+
+ Examples:
+ images = [{"file_name": "path/to/image1.jpg", "id": 1}, ...]
+ """
+
+ return {image["file_name"]: image["id"] for image in images}
+
+
+def map_id_to_annotations(annotations: list[dict]) -> dict[int, list[int]]:
+ """
+ Maps image IDs to their corresponding annotation indices.
+
+ Args:
+ annotations: List of dictionaries containing annotation data. Each dictionary
+ should have 'image_id' key.
+
+ Returns:
+ A dictionary mapping image IDs to lists of corresponding annotation indices.
+
+ Examples:
+ annotations = [{"image_id": 1, ...}, ...]
+ """
+
+ annotation_idx_map = defaultdict(list)
+ for idx, annotation in enumerate(annotations):
+ annotation_idx_map[annotation["image_id"]].append(idx)
+
+ return annotation_idx_map
+
+
+def _crop_and_pad_image(
+ image: np.ndarray,
+ coords: tuple[tuple[int, int], tuple[int, int]],
+ output_size: tuple[int, int],
+) -> tuple[np.ndarray, tuple[int, int]]:
+ """
+ Crop the image using the given coordinates and pad the larger dimension to change
+ the aspect ratio.
+
+ Args:
+ image: Image to crop, of shape (height, width, channels).
+ coords: Coordinates for cropping as [(xmin, xmax), (ymin, ymax)].
+ output_size: The (output_h, output_w) that this cropped image will be resized
+ to. Used to compute padding to keep aspect ratios.
+
+ Returns:
+ Cropped (and possibly padded) image
+ Padding (pad_h, pad_w)
+ """
+ cropped_image = image[coords[1][0] : coords[1][1], coords[0][0] : coords[0][1], :]
+
+ crop_h, crop_w, c = cropped_image.shape
+ pad_h, pad_w = 0, 0
+ target_ratio_h = output_size[0] / crop_h
+ target_ratio_w = output_size[1] / crop_w
+
+ if target_ratio_h != target_ratio_w:
+ if crop_h < crop_w:
+ # Pad the height
+ new_h = int(crop_w * output_size[0] / output_size[1])
+ pad_h = new_h - crop_h
+ pad_image = np.zeros((new_h, crop_w, c))
+ y_offset = pad_h // 2
+ pad_image[y_offset : y_offset + crop_h, :] = cropped_image
+ else:
+ # Pad the width
+ new_w = int(crop_h * output_size[1] / output_size[0])
+ pad_w = new_w - crop_w
+ pad_image = np.zeros((crop_h, new_w, c))
+ x_offset = pad_w // 2
+ pad_image[:, x_offset : x_offset + crop_w] = cropped_image
+ else:
+ pad_image = cropped_image
+
+ return pad_image, (pad_h, pad_w)
+
+
+def _crop_and_pad_keypoints(
+ keypoints: np.ndarray, coords: tuple[int, int], pad_size: tuple[int, int]
+):
+ """
+ Adjust the keypoints after cropping and padding.
+
+ Parameters:
+ keypoints: The original keypoints, typically a 2D array of shape (..., 2).
+ coords: The (xmin, ymin) crop coordinates used for cropping the image.
+ pad_size: The padding sizes added to the cropped image, in the format (pad_h, pad_w).
+
+ Returns:
+ Adjusted keypoints.
+ """
+ keypoints[..., 0] -= coords[0]
+ keypoints[..., 1] -= coords[1]
+ keypoints[..., 0] += pad_size[1] // 2
+ keypoints[..., 1] += pad_size[0] // 2
+ return keypoints
+
+
+def _crop_image_keypoints(
+ image, keypoints, coords, output_size
+) -> tuple[np.ndarray, np.ndarray, tuple[int, int], tuple[int, int]]:
+ """TODO: Requires fixing
+ Crop the image based on a given bounding box and resize it to the desired output
+ size. Returns offsets and scales to map keypoints in the resized image to
+ coordinates in the original image:
+
+ x_original = (x_cropped * x_scale) + x_offset
+ y_original = (y_cropped * y_scale) + y_offset
+
+ Args:
+ image: Image to crop, of shape (height, width, channels).
+ coords: Coordinates for cropping as ((xmin, xmax), (ymin, ymax)).
+ output_size: The (h, w) that the cropped image should be resized to.
+
+ Returns:
+ Cropped, possibly padded, and resized image.
+ The position of the keypoints in the cropped, resized image
+ Offsets used for cropping.
+ The offsets to map predicted keypoints back to the original image
+ The scale to map predicted keypoints back to the original image
+ """
+
+ cropped_image, pad_size = _crop_and_pad_image(image, coords, output_size)
+ cropped_keypoints = _crop_and_pad_keypoints(
+ keypoints, (coords[0][0], coords[1][0]), pad_size
+ )
+
+ offsets = (coords[0][0], coords[1][0])
+ scales = [
+ output_size[0] / cropped_image.shape[0],
+ output_size[1] / cropped_image.shape[1],
+ ]
+
+ # TODO: Fix resizing, use OpenCV
+ cropped_resized_image = np.resize(
+ cropped_image, (*output_size, cropped_image.shape[2])
+ )
+
+ cropped_resized_keypoints = np.array(cropped_keypoints) * np.array(scales + [1])
+
+ return cropped_resized_image, cropped_resized_keypoints, offsets, scales
+
+
+def _compute_crop_bounds(
+ bboxes: np.ndarray,
+ image_shape: tuple[int, int, int],
+ remove_empty: bool = True,
+) -> np.ndarray:
+ """
+ Compute the boundaries for cropping an image based on a COCO-format bounding box
+ and image shape by clipping values so the bounding boxes are entirely in the image.
+
+ Args:
+ bboxes: COCO-format bounding box of shape (b, xywh)
+ image_shape: Shape of the image defined as (height, width, channels).
+
+ Returns:
+ The bounding boxes, clipped to be entirely inside the image
+ """
+ h, w = image_shape[:2]
+ # to xyxy
+ bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
+ bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
+ # clip
+ bboxes = np.clip(bboxes, 0, np.array([w, h, w, h]))
+ # to xywh
+ bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
+ bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
+ # filter
+ if remove_empty:
+ squashed_bbox_mask = np.logical_or(bboxes[:, 2] <= 0, bboxes[:, 3] <= 0)
+ bboxes = bboxes[~squashed_bbox_mask]
+ return bboxes
+
+
+def _extract_keypoints_and_bboxes(
+ anns: list[dict],
+ image_shape: tuple[int, int, int],
+ num_joints: int,
+ num_unique_bodyparts: int,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict[str, np.ndarray]]:
+ """
+ Args:
+ anns: COCO-style annotations
+ image_shape: the (h, w, c) shape of the image for which to get annotations
+ num_joints: the number of joints in the annotations
+
+ Returns:
+ keypoints, unique_keypoints, bboxes in xywh format, annotations_merged
+ """
+ keypoints = []
+ original_bboxes = []
+ anns_to_merge = []
+ unique_keypoints = None
+ h, w = image_shape[:2]
+ for i, annotation in enumerate(anns):
+ keypoints_individual = _annotation_to_keypoints(annotation, h, w)
+ if annotation["individual"] != "single":
+ bbox_individual = annotation["bbox"]
+ original_bboxes.append(bbox_individual)
+ keypoints.append(keypoints_individual)
+ anns_to_merge.append(annotation)
+ else:
+ unique_keypoints = keypoints_individual
+
+ if unique_keypoints is None:
+ unique_keypoints = -1 * np.ones((num_unique_bodyparts, 3), dtype=float)
+
+ keypoints = safe_stack(keypoints, (0, num_joints, 3))
+ original_bboxes = safe_stack(original_bboxes, (0, 4))
+ bboxes = _compute_crop_bounds(original_bboxes, image_shape, remove_empty=False)
+
+ # at least 1 visible joint to keep individuals
+ vis_mask = (keypoints[..., 2] > 0).any(axis=1)
+ keypoints = keypoints[vis_mask]
+ bboxes = bboxes[vis_mask]
+
+ keys_to_merge = ["area", "category_id", "iscrowd", "individual_id"]
+ anns_merged = {k: [] for k in keys_to_merge}
+ if len(anns_to_merge) > 0:
+ anns_merged = merge_list_of_dicts(anns_to_merge, keys_to_include=keys_to_merge)
+ anns_merged = {k: np.array(v)[vis_mask] for k, v in anns_merged.items()}
+
+ if len(anns_merged["area"]) != len(keypoints):
+ raise ValueError(f"Missing area values! {anns_merged}, {keypoints.shape}")
+
+ return keypoints, unique_keypoints, bboxes, anns_merged
+
+
+def calc_area_from_keypoints(keypoints: np.ndarray) -> np.ndarray:
+ """
+ Calculate the area from keypoints
+
+ TODO: in the pups benchmark, there are 5 keypoints perfectly aligned so
+ the area is 0.
+ How do we deal with that?
+ Makes more sense to compute the area from the bboxes (they are padded)
+ Below is a temporary fix, which sets a min height and width to 5
+ Suggestion: compute min height/width using labeled data
+
+ Args:
+ keypoints (np.ndarray): array of keypoints
+
+ Returns:
+ np.ndarray: array containing the computed areas based on the keypoints
+ """
+ w = np.maximum(keypoints[:, :, 0].max(axis=1) - keypoints[:, :, 0].min(axis=1), 1)
+ h = np.maximum(keypoints[:, :, 1].max(axis=1) - keypoints[:, :, 1].min(axis=1), 1)
+ return w * h
+
+
+def _annotation_to_keypoints(annotation: dict, h: int, w: int) -> np.array:
+ """
+ Convert the coco annotations into array of keypoints returns the array of the
+ keypoints' visibility. If keypoint is not visible, the value for (x,y) coordinates
+ is set to 0. If the keypoints are outside of the image, they are also set to 0.
+
+ Args:
+ annotation: dictionary containing coco-like annotations with essential
+ `keypoints` field
+ h: the image height
+ w: the image width
+
+ Returns:
+ keypoints: np.array where the first two columns are x and y coordinates of the
+
+ """
+ # we don't mess up visibility flags here
+ return annotation["keypoints"].reshape(-1, 3)
+
+
+def apply_transform(
+ transform: A.BaseCompose,
+ image: np.ndarray,
+ keypoints: np.ndarray,
+ bboxes: np.ndarray,
+ class_labels: list[str],
+) -> dict[str, np.ndarray]:
+ """
+ Applies a transformation to the provided image and keypoints.
+
+ Args:
+ transform: The transformation to apply.
+ image: The input image to which the transformation will be applied.
+ keypoints: List of keypoints to be transformed along with the image. Each keypoint
+ is expected to be a tuple or list with at least three values,
+ where the third value indicates the class label index.
+ bboxes: List of bounding boxes to be transformed along with the image.
+ class_labels: List of class labels corresponding to the keypoints.
+
+ Returns:
+ transformed: A dictionary containing the transformed image and keypoints.
+ """
+
+ if transform:
+ oob_mask = out_of_bounds_keypoints(keypoints, image.shape)
+ transformed = _apply_transform(
+ transform, image, keypoints, bboxes, class_labels
+ )
+
+ transformed["keypoints"] = np.array(transformed["keypoints"])
+
+ # out-of-bound keypoints have visibility flag 0. But we don't touch coordinates
+ if np.sum(oob_mask) > 0:
+ transformed["keypoints"][oob_mask, 2] = 0.0
+
+ out_shape = transformed["image"].shape
+ if len(transformed["keypoints"]) > 0:
+ oob_mask = out_of_bounds_keypoints(transformed["keypoints"], out_shape)
+ # out-of-bound keypoints have visibility flag 0. Don't touch coordinates
+ if np.sum(oob_mask) > 0:
+ transformed["keypoints"][oob_mask, 2] = 0.0
+
+ # TODO: Check that the transformed bboxes are still within the image
+ if len(transformed["bboxes"]) > 0:
+ transformed["bboxes"] = np.array(transformed["bboxes"])
+ else:
+ transformed["bboxes"] = np.zeros(shape=(0, 4))
+
+ else:
+ transformed = {"keypoints": keypoints, "image": image}
+
+ # do we ever need to do this if we had check_keypoints_within_bounds above?
+ # np.nan_to_num(transformed["keypoints"], copy=False, nan=-1)
+ return transformed
+
+
+def _apply_transform(
+ transform: A.BaseCompose,
+ image: np.ndarray,
+ keypoints: np.ndarray,
+ bboxes: np.ndarray,
+ class_labels: list[str],
+) -> dict[str, np.ndarray]:
+ """
+ Applies a transformation to the provided image and keypoints.
+
+ Args:
+ image : np.array or similar image data format
+ The input image to which the transformation will be applied.
+
+ keypoints : list or similar data format
+ List of keypoints to be transformed along with the image. Each keypoint
+ is expected to be a tuple or list with at least three values,
+ where the third value indicates the class label index.
+
+ Returns:
+ dict
+ A dictionary containing the transformed image and keypoints.
+ """
+ transformed = transform(
+ image=image,
+ keypoints=keypoints,
+ class_labels=class_labels,
+ bboxes=bboxes,
+ bbox_labels=np.arange(len(bboxes)),
+ )
+
+ bboxes_out = np.zeros(bboxes.shape)
+ for bbox, bbox_id in zip(transformed["bboxes"], transformed["bbox_labels"]):
+ bboxes_out[bbox_id] = bbox
+
+ transformed["bboxes"] = bboxes_out
+ return transformed
+
+
+def out_of_bounds_keypoints(keypoints: np.ndarray, shape: tuple) -> np.ndarray:
+ """Computes which visible keypoints are outside an image
+
+ Args:
+ keypoints: A (N, 3) shaped array where N is the number of keypoints and each
+ keypoint is represented as (x, y, visibility).
+ shape: A tuple representing the shape or bounds as (height, width).
+
+ Returns:
+ A boolean array of shape (N,) where each element corresponds to whether
+ the respective keypoint is visible (visibility > 0) and outside the image
+ bounds. This mask can be used to set the visibility bit to 0 for keypoints that
+ were kicked off an image due to augmentation.
+ """
+ return (keypoints[..., 2] > 0) & (
+ np.isnan(keypoints[..., 0])
+ | np.isnan(keypoints[..., 1])
+ | (keypoints[..., 0] < 0)
+ | (keypoints[..., 0] > shape[1])
+ | (keypoints[..., 1] < 0)
+ | (keypoints[..., 1] > shape[0])
+ )
+
+
+def pad_to_length(data: np.array, length: int, value: float) -> np.array:
+ """
+ Pads the first dimension of an array with a given value
+
+ Args:
+ data: the array to pad, of shape (l, ...), where l <= length
+ length: the desired length of the tensor
+ value: the value to pad with
+
+ Returns:
+ the padded array of shape (length, ...)
+ """
+ pad_length = length - len(data)
+ if pad_length == 0:
+ return data
+ elif pad_length > 0:
+ padding = value * np.ones((pad_length, *data.shape[1:]), dtype=data.dtype)
+ return np.concatenate([data, padding])
+
+ raise ValueError(f"Cannot pad! data.shape={data.shape} > length={length}")
+
+
+def safe_stack(data: list[np.ndarray], default_shape: tuple[int, ...]) -> np.ndarray:
+ """
+ Stacks a list of arrays if there are any, otherwise returns an array of zeros
+ of a desired shape.
+
+ Args:
+ data: the list of arrays to stack
+ default_shape: the shape of the array to return if the list is empty
+
+ Returns:
+ the stacked data or empty array
+ """
+ if len(data) == 0:
+ return np.zeros(default_shape, dtype=float)
+
+ return np.stack(data, axis=0)
diff --git a/deeplabcut/pose_estimation_pytorch/metrics/__init__.py b/deeplabcut/pose_estimation_pytorch/metrics/__init__.py
new file mode 100644
index 0000000000..117d127147
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/metrics/__init__.py
@@ -0,0 +1,10 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
diff --git a/deeplabcut/pose_estimation_pytorch/metrics/scoring.py b/deeplabcut/pose_estimation_pytorch/metrics/scoring.py
new file mode 100644
index 0000000000..317bc87c13
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/metrics/scoring.py
@@ -0,0 +1,75 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import numpy as np
+import pickle
+from sklearn.metrics import accuracy_score
+
+from deeplabcut.core.crossvalutils import find_closest_neighbors
+from deeplabcut.utils.auxiliaryfunctions import read_config
+
+
+def _match_identity_preds_to_gt(
+ config_path: str, full_pickle_path: str
+) -> tuple[np.ndarray, list]:
+ with open(full_pickle_path, "rb") as f:
+ data = pickle.load(f)
+ metadata = data.pop("metadata")
+ cfg = read_config(config_path)
+ all_ids = cfg["individuals"].copy()
+ all_bpts = cfg["multianimalbodyparts"] * len(all_ids)
+ n_multibodyparts = len(all_bpts)
+ if cfg["uniquebodyparts"]:
+ all_ids += ["single"]
+ all_bpts += cfg["uniquebodyparts"]
+ all_bpts = np.asarray(all_bpts)
+ joints = metadata["all_joints_names"]
+ ids = np.full((len(data), len(all_bpts), 2), np.nan)
+ for i, dict_ in enumerate(data.values()):
+ id_gt, _, df_gt = dict_["groundtruth"]
+ for j, id_ in enumerate(id_gt):
+ if id_.size:
+ ids[i, j, 0] = all_ids.index(id_)
+
+ df = df_gt.unstack("coords").reindex(joints, level="bodyparts")
+ xy_pred = dict_["prediction"]["coordinates"][0]
+ for bpt, xy_gt in df.groupby(level="bodyparts"):
+ inds_gt = np.flatnonzero(np.all(~np.isnan(xy_gt), axis=1))
+ n_joint = joints.index(bpt)
+ xy = xy_pred[n_joint]
+ if inds_gt.size and xy.size:
+ # Pick the predictions closest to ground truth,
+ # rather than the ones the model has most confident in
+ xy_gt_values = xy_gt.iloc[inds_gt].values
+ neighbors = find_closest_neighbors(xy_gt_values, xy, k=3)
+ found = neighbors != -1
+ inds = np.flatnonzero(all_bpts == bpt)
+ id_ = dict_["prediction"]["identity"][n_joint]
+ ids[i, inds[inds_gt[found]], 1] = np.argmax(
+ id_[neighbors[found]], axis=1
+ )
+ ids = ids[:, :n_multibodyparts].reshape((len(data), len(cfg["individuals"]), -1, 2))
+ return ids, list(data)
+
+
+def compute_id_accuracy(ids: np.ndarray, mask_test: np.ndarray) -> np.ndarray:
+ nbpts = ids.shape[2] # ids shape is (n_images, n_individuals, n_bodyparts, 2)
+ accu = np.empty((nbpts, 2))
+ for i in range(nbpts):
+ temp = ids[:, :, i].reshape((-1, 2))
+ valid = np.isfinite(temp).all(axis=1)
+ y_true, y_pred = temp[valid].T
+ mask = np.repeat(mask_test, ids.shape[1])[valid]
+ ac_train = accuracy_score(y_true[~mask], y_pred[~mask])
+ ac_test = accuracy_score(y_true[mask], y_pred[mask])
+ accu[i] = ac_train, ac_test
+ return accu
diff --git a/deeplabcut/pose_estimation_pytorch/models/__init__.py b/deeplabcut/pose_estimation_pytorch/models/__init__.py
new file mode 100644
index 0000000000..6e28f8722c
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/__init__.py
@@ -0,0 +1,23 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.backbones.base import BACKBONES
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ CRITERIONS,
+ LOSS_AGGREGATORS,
+)
+from deeplabcut.pose_estimation_pytorch.models.detectors import DETECTORS
+from deeplabcut.pose_estimation_pytorch.models.heads.base import HEADS
+from deeplabcut.pose_estimation_pytorch.models.model import PoseModel
+from deeplabcut.pose_estimation_pytorch.models.necks.base import NECKS
+from deeplabcut.pose_estimation_pytorch.models.predictors import PREDICTORS
+from deeplabcut.pose_estimation_pytorch.models.target_generators import (
+ TARGET_GENERATORS,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/models/backbones/__init__.py b/deeplabcut/pose_estimation_pytorch/models/backbones/__init__.py
new file mode 100644
index 0000000000..fba0d66b2a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/backbones/__init__.py
@@ -0,0 +1,17 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.backbones.base import (
+ BACKBONES,
+ BaseBackbone,
+)
+from deeplabcut.pose_estimation_pytorch.models.backbones.cspnext import CSPNeXt
+from deeplabcut.pose_estimation_pytorch.models.backbones.hrnet import HRNet
+from deeplabcut.pose_estimation_pytorch.models.backbones.resnet import ResNet, DLCRNet
diff --git a/deeplabcut/pose_estimation_pytorch/models/backbones/base.py b/deeplabcut/pose_estimation_pytorch/models/backbones/base.py
new file mode 100644
index 0000000000..bf2febe9ec
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/backbones/base.py
@@ -0,0 +1,141 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import logging
+import shutil
+from abc import ABC, abstractmethod
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from huggingface_hub import hf_hub_download
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+BACKBONES = Registry("backbones", build_func=build_from_cfg)
+
+
+class BaseBackbone(ABC, nn.Module):
+ """Base Backbone class for pose estimation.
+
+ Attributes:
+ stride: the stride for the backbone
+ freeze_bn_weights: freeze weights of batch norm layers during training
+ freeze_bn_stats: freeze stats of batch norm layers during training
+ """
+
+ def __init__(
+ self,
+ stride: int | float,
+ freeze_bn_weights: bool = True,
+ freeze_bn_stats: bool = True,
+ ):
+ super().__init__()
+ self.stride = stride
+ self.freeze_bn_weights = freeze_bn_weights
+ self.freeze_bn_stats = freeze_bn_stats
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Abstract method for the forward pass through the backbone.
+
+ Args:
+ x: Input tensor of shape (batch_size, channels, height, width).
+
+ Returns:
+ a feature map for the input, of shape (batch_size, c', h', w')
+ """
+ pass
+
+ def freeze_batch_norm_layers(self) -> None:
+ """Freezes batch norm layers
+
+ Running mean + var are always given to F.batch_norm, except when the layer is
+ in `train` mode and track_running_stats is False, see
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html
+ So to 'freeze' the running stats, the only way is to set the layer to "eval"
+ mode.
+ """
+ for module in self.modules():
+ if isinstance(module, nn.BatchNorm2d):
+ if self.freeze_bn_weights:
+ module.weight.requires_grad = False
+ module.bias.requires_grad = False
+ if self.freeze_bn_stats:
+ module.eval()
+
+ def train(self, mode: bool = True) -> None:
+ """Sets the module in training or evaluation mode.
+
+ Args:
+ mode: whether to set training mode (True) or evaluation mode (False)
+ """
+ super().train(mode)
+ if self.freeze_bn_weights or self.freeze_bn_stats:
+ self.freeze_batch_norm_layers()
+
+
+class HuggingFaceWeightsMixin:
+ """Mixin for backbones where the pretrained weights are stored on HuggingFace"""
+
+ def __init__(
+ self,
+ backbone_weight_folder: str | Path | None = None,
+ repo_id: str = "DeepLabCut/DeepLabCut-Backbones",
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ if backbone_weight_folder is None:
+ backbone_weight_folder = Path(__file__).parent / "pretrained_weights"
+ else:
+ backbone_weight_folder = Path(backbone_weight_folder).resolve()
+
+ self.backbone_weight_folder = backbone_weight_folder
+ self.repo_id = repo_id
+
+ def download_weights(self, filename: str, force: bool = False) -> Path:
+ """Downloads the backbone weights from the HuggingFace repo
+
+ Args:
+ filename: The name of the model file to download in the repo.
+ force: Whether to re-download the file if it already exists locally.
+
+ Returns:
+ The path to the model snapshot.
+ """
+ model_path = self.backbone_weight_folder / filename
+ if model_path.exists():
+ if not force:
+ return model_path
+ model_path.unlink()
+
+ logging.info(f"Downloading the pre-trained backbone to {model_path}")
+ self.backbone_weight_folder.mkdir(exist_ok=True, parents=False)
+ output_path = Path(
+ hf_hub_download(
+ self.repo_id, filename, cache_dir=self.backbone_weight_folder
+ )
+ )
+
+ # resolve gets the actual path if the output path is a symlink
+ output_path = output_path.resolve()
+ # move to the target path
+ output_path.rename(model_path)
+
+ # delete downloaded artifacts
+ uid, rid = self.repo_id.split("/")
+ artifact_dir = self.backbone_weight_folder / f"models--{uid}--{rid}"
+ if artifact_dir.exists():
+ shutil.rmtree(artifact_dir)
+
+ return model_path
diff --git a/deeplabcut/pose_estimation_pytorch/models/backbones/cspnext.py b/deeplabcut/pose_estimation_pytorch/models/backbones/cspnext.py
new file mode 100644
index 0000000000..50718940b9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/backbones/cspnext.py
@@ -0,0 +1,207 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Implementation of the CSPNeXt Backbone
+
+Based on the ``mmdetection`` CSPNeXt implementation. For more information, see:
+
+
+For more details about this architecture, see `RTMDet: An Empirical Study of Designing
+Real-Time Object Detectors`: https://arxiv.org/abs/1711.05101.
+"""
+from dataclasses import dataclass
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.backbones.base import (
+ BACKBONES,
+ BaseBackbone,
+ HuggingFaceWeightsMixin,
+)
+from deeplabcut.pose_estimation_pytorch.models.modules.csp import (
+ CSPConvModule,
+ CSPLayer,
+ SPPBottleneck,
+)
+
+
+@dataclass(frozen=True)
+class CSPNeXtLayerConfig:
+ """Configuration for a CSPNeXt layer"""
+ in_channels: int
+ out_channels: int
+ num_blocks: int
+ add_identity: bool
+ use_spp: bool
+
+
+@BACKBONES.register_module
+class CSPNeXt(HuggingFaceWeightsMixin, BaseBackbone):
+ """CSPNeXt Backbone
+
+ Args:
+ model_name: The model variant to build. If ``pretrained==True``, must be one of
+ the variants for which weights are available on HuggingFace (in the
+ `DeepLabCut/DeepLabCut-Backbones` hub, e.g. `cspnext_m`).
+ pretrained: Whether to load pretrained weights for the model.
+ arch: The model architecture to build. Must be one of the keys of the
+ ``CSPNeXt.ARCH`` attribute (e.g. `P5`, `P6`, ...).
+ expand_ratio: Ratio used to adjust the number of channels of the hidden layer.
+ deepen_factor: Number of blocks in each CSP layer is multiplied by this value.
+ widen_factor: Number of channels in each layer is multiplied by this value.
+ out_indices: The branch indices to output. If a tuple of integers, the outputs
+ are returned as a list of tensors. If a single integer, a tensor is returned
+ containing the configured index.
+ channel_attention: Add channel attention to all stages
+ norm_layer: The type of normalization layer to use.
+ activation_fn: The type of activation function to use.
+ **kwargs: BaseBackbone kwargs.
+ """
+
+ ARCH: dict[str, list[CSPNeXtLayerConfig]] = {
+ "P5": [
+ CSPNeXtLayerConfig(64, 128, 3, True, False),
+ CSPNeXtLayerConfig(128, 256, 6, True, False),
+ CSPNeXtLayerConfig(256, 512, 6, True, False),
+ CSPNeXtLayerConfig(512, 1024, 3, False, True),
+ ],
+ "P6": [
+ CSPNeXtLayerConfig(64, 128, 3, True, False),
+ CSPNeXtLayerConfig(128, 256, 6, True, False),
+ CSPNeXtLayerConfig(256, 512, 6, True, False),
+ CSPNeXtLayerConfig(512, 768, 3, True, False),
+ CSPNeXtLayerConfig(768, 1024, 3, False, True),
+ ]
+ }
+
+ def __init__(
+ self,
+ model_name: str = "cspnext_m",
+ pretrained: bool = False,
+ arch: str = "P5",
+ expand_ratio: float = 0.5,
+ deepen_factor: float = 0.67,
+ widen_factor: float = 0.75,
+ out_indices: int | tuple[int, ...] = -1,
+ channel_attention: bool = True,
+ norm_layer: str = "SyncBN",
+ activation_fn: str = "SiLU",
+ **kwargs,
+ ) -> None:
+ super().__init__(stride=32, **kwargs)
+ if arch not in self.ARCH:
+ raise ValueError(
+ f"Unknown `CSPNeXT` architecture: {arch}. Must be one of "
+ f"{self.ARCH.keys()}"
+ )
+
+ self.model_name = model_name
+ self.layer_configs = self.ARCH[arch]
+ self.stem_out_channels = self.layer_configs[0].in_channels
+ self.spp_kernel_sizes = (5, 9, 13)
+
+ # stem has stride 2
+ self.stem = nn.Sequential(
+ CSPConvModule(
+ in_channels=3,
+ out_channels=int(self.stem_out_channels * widen_factor // 2),
+ kernel_size=3,
+ padding=1,
+ stride=2,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ ),
+ CSPConvModule(
+ in_channels=int(self.stem_out_channels * widen_factor // 2),
+ out_channels=int(self.stem_out_channels * widen_factor // 2),
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ ),
+ CSPConvModule(
+ in_channels=int(self.stem_out_channels * widen_factor // 2),
+ out_channels=int(self.stem_out_channels * widen_factor),
+ kernel_size=3,
+ padding=1,
+ stride=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ )
+ self.layers = ["stem"]
+
+ for i, layer_cfg in enumerate(self.layer_configs):
+ layer_cfg: CSPNeXtLayerConfig
+ in_channels = int(layer_cfg.in_channels * widen_factor)
+ out_channels = int(layer_cfg.out_channels * widen_factor)
+ num_blocks = max(round(layer_cfg.num_blocks * deepen_factor), 1)
+ stage = []
+ conv_layer = CSPConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ stage.append(conv_layer)
+ if layer_cfg.use_spp:
+ spp = SPPBottleneck(
+ out_channels,
+ out_channels,
+ kernel_sizes=self.spp_kernel_sizes,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ stage.append(spp)
+
+ csp_layer = CSPLayer(
+ out_channels,
+ out_channels,
+ num_blocks=num_blocks,
+ add_identity=layer_cfg.add_identity,
+ expand_ratio=expand_ratio,
+ channel_attention=channel_attention,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ stage.append(csp_layer)
+ self.add_module(f'stage{i + 1}', nn.Sequential(*stage))
+ self.layers.append(f'stage{i + 1}')
+
+ self.single_output = isinstance(out_indices, int)
+ if self.single_output:
+ if out_indices == -1:
+ out_indices = len(self.layers) - 1
+ out_indices = (out_indices,)
+ self.out_indices = out_indices
+
+ if pretrained:
+ weights_filename = f"{model_name}.pt"
+ weights_path = self.download_weights(weights_filename, force=False)
+ snapshot = torch.load(weights_path, map_location="cpu", weights_only=True)
+ self.load_state_dict(snapshot["state_dict"])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor]:
+ outs = []
+ for i, layer_name in enumerate(self.layers):
+ layer = getattr(self, layer_name)
+ x = layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ if self.single_output:
+ return outs[-1]
+
+ return tuple(outs)
diff --git a/deeplabcut/pose_estimation_pytorch/models/backbones/hrnet.py b/deeplabcut/pose_estimation_pytorch/models/backbones/hrnet.py
new file mode 100644
index 0000000000..3180bb53e4
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/backbones/hrnet.py
@@ -0,0 +1,122 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.backbones.base import (
+ BACKBONES,
+ BaseBackbone,
+)
+
+
+@BACKBONES.register_module
+class HRNet(BaseBackbone):
+ """HRNet backbone.
+
+ This version returns high-resolution feature maps of size 1/4 * original_image_size.
+ This is obtained using bilinear interpolation and concatenation of all the outputs
+ of the HRNet stages.
+
+ The model outputs 4 branches, with strides 4, 8, 16 and 32.
+
+ Args:
+ stride: The stride of the HRNet. Should always be 4, except for custom models.
+ model_name: Any HRNet variant available through timm (e.g., 'hrnet_w32',
+ 'hrnet_w48'). See timm for more options.
+ pretrained: If True, loads the backbone with ImageNet pretrained weights from
+ timm.
+ interpolate_branches: Needed for DEKR. Instead of returning features from the
+ high-resolution branch, interpolates all other branches to the same shape
+ and concatenates them.
+ increased_channel_count: As described by timm, it "allows grabbing increased
+ channel count features using part of the classification head" (otherwise,
+ the default features are returned).
+ kwargs: BaseBackbone kwargs
+
+ Attributes:
+ model: the HRNet model
+ """
+
+ def __init__(
+ self,
+ stride: int = 4,
+ model_name: str = "hrnet_w32",
+ pretrained: bool = False,
+ interpolate_branches: bool = False,
+ increased_channel_count: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(stride=stride, **kwargs)
+ self.model = _load_hrnet(model_name, pretrained, increased_channel_count)
+ self.interpolate_branches = interpolate_branches
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the HRNet backbone.
+
+ Args:
+ x: Input tensor of shape (batch_size, channels, height, width).
+
+ Returns:
+ the feature map
+
+ Example:
+ >>> import torch
+ >>> from deeplabcut.pose_estimation_pytorch.models.backbones import HRNet
+ >>> backbone = HRNet(model_name='hrnet_w32', pretrained=False)
+ >>> x = torch.randn(1, 3, 256, 256)
+ >>> y = backbone(x)
+ """
+ y_list = self.model(x)
+ if not self.interpolate_branches:
+ return y_list[0]
+
+ x0_h, x0_w = y_list[0].size(2), y_list[0].size(3)
+ x = torch.cat(
+ [
+ y_list[0],
+ F.interpolate(y_list[1], size=(x0_h, x0_w), mode="bilinear"),
+ F.interpolate(y_list[2], size=(x0_h, x0_w), mode="bilinear"),
+ F.interpolate(y_list[3], size=(x0_h, x0_w), mode="bilinear"),
+ ],
+ 1,
+ )
+ return x
+
+
+def _load_hrnet(
+ model_name: str,
+ pretrained: bool,
+ increased_channel_count: bool,
+) -> nn.Module:
+ """Loads a TIMM HRNet model.
+
+ Args:
+ model_name: Any HRNet variant available through timm (e.g., 'hrnet_w32',
+ 'hrnet_w48'). See timm for more options.
+ pretrained: If True, loads the backbone with ImageNet pretrained weights from
+ timm.
+ increased_channel_count: As described by timm, it "allows grabbing increased
+ channel count features using part of the classification head" (otherwise,
+ the default features are returned).
+
+ Returns:
+ the HRNet model
+ """
+ # First stem conv is used for stride 2 features, so only return branches 1-4
+ return timm.create_model(
+ model_name,
+ pretrained=pretrained,
+ features_only=True,
+ feature_location="incre" if increased_channel_count else "",
+ out_indices=(1, 2, 3, 4),
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/models/backbones/resnet.py b/deeplabcut/pose_estimation_pytorch/models/backbones/resnet.py
new file mode 100644
index 0000000000..5103ae64a9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/backbones/resnet.py
@@ -0,0 +1,151 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import timm
+import torch
+import torch.nn as nn
+from torchvision.transforms.functional import resize
+
+from deeplabcut.pose_estimation_pytorch.models.backbones.base import (
+ BACKBONES,
+ BaseBackbone,
+)
+
+
+@BACKBONES.register_module
+class ResNet(BaseBackbone):
+ """ResNet backbone.
+
+ This class represents a typical ResNet backbone for pose estimation.
+
+ Attributes:
+ model: the ResNet model
+ """
+
+ def __init__(
+ self,
+ model_name: str = "resnet50",
+ output_stride: int = 32,
+ pretrained: bool = False,
+ drop_path_rate: float = 0.0,
+ drop_block_rate: float = 0.0,
+ **kwargs,
+ ) -> None:
+ """Initialize the ResNet backbone.
+
+ Args:
+ model_name: Name of the ResNet model to use, e.g., 'resnet50', 'resnet101'
+ output_stride: Output stride of the network, 32, 16, or 8.
+ pretrained: If True, initializes with ImageNet pretrained weights.
+ drop_path_rate: Stochastic depth drop-path rate
+ drop_block_rate: Drop block rate
+ kwargs: BaseBackbone kwargs
+ """
+ super().__init__(stride=output_stride, **kwargs)
+ self.model = timm.create_model(
+ model_name,
+ output_stride=output_stride,
+ pretrained=pretrained,
+ drop_path_rate=drop_path_rate,
+ drop_block_rate=drop_block_rate,
+ )
+ self.model.fc = nn.Identity() # remove the FC layer
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the ResNet backbone.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ Example:
+ >>> import torch
+ >>> from deeplabcut.pose_estimation_pytorch.models.backbones import ResNet
+ >>> backbone = ResNet(model_name='resnet50', pretrained=False)
+ >>> x = torch.randn(1, 3, 256, 256)
+ >>> y = backbone(x)
+
+ Expected Output Shape:
+ If input size is (batch_size, 3, shape_x, shape_y), the output shape
+ will be (batch_size, 3, shape_x//16, shape_y//16)
+ """
+ return self.model.forward_features(x)
+
+
+@BACKBONES.register_module
+class DLCRNet(ResNet):
+ def __init__(
+ self,
+ model_name: str = "resnet50",
+ output_stride: int = 32,
+ pretrained: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(model_name, output_stride, pretrained, **kwargs)
+ self.interm_features = {}
+ self.model.layer1[2].register_forward_hook(self._get_features("bank1"))
+ self.model.layer2[2].register_forward_hook(self._get_features("bank2"))
+ self.conv_block1 = self._make_conv_block(
+ in_channels=512, out_channels=512, kernel_size=3, stride=2
+ )
+ self.conv_block2 = self._make_conv_block(
+ in_channels=512, out_channels=128, kernel_size=1, stride=1
+ )
+ self.conv_block3 = self._make_conv_block(
+ in_channels=256, out_channels=256, kernel_size=3, stride=2
+ )
+ self.conv_block4 = self._make_conv_block(
+ in_channels=256, out_channels=256, kernel_size=3, stride=2
+ )
+ self.conv_block5 = self._make_conv_block(
+ in_channels=256, out_channels=128, kernel_size=1, stride=1
+ )
+
+ def _make_conv_block(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int,
+ momentum: float = 0.001, # (1 - decay)
+ ) -> torch.nn.Sequential:
+ return nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride
+ ),
+ nn.BatchNorm2d(out_channels, momentum=momentum),
+ nn.ReLU(),
+ )
+
+ def _get_features(self, name):
+ def hook(model, input, output):
+ self.interm_features[name] = output.detach()
+
+ return hook
+
+ def forward(self, x):
+ out = super().forward(x)
+
+ # Fuse intermediate features
+ bank_2_s8 = self.interm_features["bank2"]
+ bank_1_s4 = self.interm_features["bank1"]
+ bank_2_s16 = self.conv_block1(bank_2_s8)
+ bank_2_s16 = self.conv_block2(bank_2_s16)
+ bank_1_s8 = self.conv_block3(bank_1_s4)
+ bank_1_s16 = self.conv_block4(bank_1_s8)
+ bank_1_s16 = self.conv_block5(bank_1_s16)
+ # Resizing here is required to guarantee all shapes match, as
+ # Conv2D(..., padding='same') is invalid for strided convolutions.
+ h, w = out.shape[-2:]
+ bank_1_s16 = resize(bank_1_s16, [h, w], antialias=True)
+ bank_2_s16 = resize(bank_2_s16, [h, w], antialias=True)
+
+ return torch.cat((bank_1_s16, bank_2_s16, out), dim=1)
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/__init__.py b/deeplabcut/pose_estimation_pytorch/models/criterions/__init__.py
new file mode 100644
index 0000000000..c1b07634ae
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/__init__.py
@@ -0,0 +1,31 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.criterions.aggregators import (
+ WeightedLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.criterions.base import (
+ CRITERIONS,
+ LOSS_AGGREGATORS,
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.criterions.dekr import (
+ DEKRHeatmapLoss,
+ DEKROffsetLoss,
+)
+from deeplabcut.pose_estimation_pytorch.models.criterions.kl_discrete import (
+ KLDiscreteLoss,
+)
+from deeplabcut.pose_estimation_pytorch.models.criterions.weighted import (
+ WeightedBCECriterion,
+ WeightedHuberCriterion,
+ WeightedMSECriterion,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/aggregators.py b/deeplabcut/pose_estimation_pytorch/models/criterions/aggregators.py
new file mode 100644
index 0000000000..973cabfc63
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/aggregators.py
@@ -0,0 +1,31 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.criterions.base import (
+ BaseLossAggregator,
+ LOSS_AGGREGATORS,
+)
+
+
+@LOSS_AGGREGATORS.register_module
+class WeightedLossAggregator(BaseLossAggregator):
+ def __init__(self, weights: dict[str, float]) -> None:
+ super().__init__()
+ self.weights = weights
+
+ def forward(self, losses: dict[str, torch.Tensor]) -> torch.Tensor:
+ weighted_losses = [
+ weight * losses[loss_name] for loss_name, weight in self.weights.items()
+ ]
+ return torch.mean(torch.stack(weighted_losses))
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/base.py b/deeplabcut/pose_estimation_pytorch/models/criterions/base.py
new file mode 100644
index 0000000000..8520366b7f
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/base.py
@@ -0,0 +1,54 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+LOSS_AGGREGATORS = Registry("loss_aggregators", build_func=build_from_cfg)
+CRITERIONS = Registry("criterions", build_func=build_from_cfg)
+
+
+class BaseCriterion(ABC, nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def forward(
+ self, output: torch.Tensor, target: torch.Tensor, **kwargs
+ ) -> torch.Tensor:
+ """
+ Args:
+ output: the output from which to compute the loss
+ target: the target for the loss
+
+ Returns:
+ the different losses for the module, including one "total_loss" key which
+ is the loss from which to start backpropagation
+ """
+ raise NotImplementedError
+
+
+class BaseLossAggregator(ABC, nn.Module):
+ @abstractmethod
+ def forward(self, losses: dict[str, torch.Tensor]) -> torch.Tensor:
+ """
+ Args:
+ losses: the losses to aggregate
+
+ Returns:
+ the aggregate loss
+ """
+ raise NotImplementedError
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/dekr.py b/deeplabcut/pose_estimation_pytorch/models/criterions/dekr.py
new file mode 100644
index 0000000000..ab18007884
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/dekr.py
@@ -0,0 +1,85 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Loss criterions for DEKR models"""
+from __future__ import annotations
+
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.criterions.base import (
+ BaseCriterion,
+ CRITERIONS,
+)
+
+
+@CRITERIONS.register_module
+class DEKRHeatmapLoss(BaseCriterion):
+ """DEKR Heatmap loss"""
+
+ def forward(
+ self,
+ output: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor | float = 1.0,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ output: the output from which to compute the loss
+ target: the target for the loss
+ weights: the weights for the loss
+
+ Returns:
+ the DEKR offset loss
+ """
+ assert output.size() == target.size()
+ loss = ((output - target) ** 2) * weights
+ return loss.mean(dim=3).mean(dim=2).mean(dim=1).mean(dim=0)
+
+
+@CRITERIONS.register_module
+class DEKROffsetLoss(BaseCriterion):
+ """DEKR Offset loss"""
+
+ def __init__(self, beta: float = 1 / 9):
+ super().__init__()
+ self.beta = beta
+
+ def smooth_l1_loss(self, pred, gt):
+ l1_loss = torch.abs(pred - gt)
+ return torch.where(
+ l1_loss < self.beta,
+ 0.5 * l1_loss ** 2 / self.beta,
+ l1_loss - 0.5 * self.beta,
+ )
+
+ def forward(
+ self,
+ output: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor | float = 1.0,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ output: the output from which to compute the loss
+ target: the target for the loss
+ weights: the weights for the loss
+
+ Returns:
+ the DEKR offset loss
+ """
+ assert output.size() == target.size()
+ num_pos = torch.nonzero(weights > 0).size()[0]
+ loss = self.smooth_l1_loss(output, target) * weights
+ if num_pos == 0:
+ num_pos = 1.0
+ loss = loss.sum() / num_pos
+ return loss
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/kl_discrete.py b/deeplabcut/pose_estimation_pytorch/models/criterions/kl_discrete.py
new file mode 100644
index 0000000000..e36bf78ae6
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/kl_discrete.py
@@ -0,0 +1,86 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""SimCC Discrete KL Divergence loss with Gaussian Label Smoothing.
+
+Can be used for SimCC-type heads. Modified from the `mmpose` implementation. For more
+details, see .
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.criterions.base import (
+ BaseCriterion,
+ CRITERIONS,
+)
+
+
+@CRITERIONS.register_module
+class KLDiscreteLoss(BaseCriterion):
+ """KLDiscrete loss
+
+ Args:
+ beta: Temperature for the softmax.
+ label_softmax: Use softmax on the labels.
+ label_beta: Temperature for the softmax on the labels.
+ use_target_weight: Allows the use a weighted loss for different joints.
+ mask: Indices of masked keypoints.
+ mask_weight: Weight for masked keypoints.
+ """
+
+ def __init__(
+ self,
+ beta: float = 1.0,
+ label_softmax: bool = False,
+ label_beta: float = 10.0,
+ use_target_weight: bool = True,
+ mask: list[int] | None = None,
+ mask_weight: float = 1.0,
+ ):
+ super().__init__()
+ self.beta = beta
+ self.label_softmax = label_softmax
+ self.label_beta = label_beta
+ self.use_target_weight = use_target_weight
+ self.mask = mask
+ self.mask_weight = mask_weight
+
+ self.log_softmax = nn.LogSoftmax(dim=1)
+ self.kl_loss = nn.KLDivLoss(reduction="none")
+
+ def forward(
+ self,
+ output: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor | float = 1.0,
+ **kwargs,
+ ) -> torch.Tensor:
+ n, k, _ = output.shape
+ if self.use_target_weight and isinstance(weights, torch.Tensor):
+ weight = weights.reshape(-1)
+ else:
+ weight = 1.0
+
+ pred = output.reshape(-1, output.size(-1))
+ target = target.reshape(-1, target.size(-1))
+ loss = self.criterion(pred, target).mul(weight)
+ if self.mask is not None:
+ loss = loss.reshape(n, k)
+ loss[:, self.mask] = loss[:, self.mask] * self.mask_weight
+
+ return loss.sum() / k
+
+ def criterion(self, dec_outs, labels):
+ log_pt = self.log_softmax(dec_outs * self.beta)
+ if self.label_softmax:
+ labels = F.softmax(labels * self.label_beta, dim=1)
+ loss = torch.mean(self.kl_loss(log_pt, labels), dim=1)
+ return loss
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/utils.py b/deeplabcut/pose_estimation_pytorch/models/criterions/utils.py
new file mode 100644
index 0000000000..693ec74e09
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/utils.py
@@ -0,0 +1,50 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+
+
+def count_nonzero_elems(
+ losses: torch.Tensor, weights: float | torch.Tensor, per_batch: bool = False
+):
+ """
+ Compute the number of elements in the loss function induced by `weights`.
+ This is a torch implementation of https://github.com/tensorflow/tensorflow/blob/4dacf3f368eb7965e9b5c3bbdd5193986081c3b2/tensorflow/python/ops/losses/losses_impl.py#L89
+
+ Args:
+ losses (Tensor): Tensor of shape [batch_size, d1, ... dN].
+ weights (Tensor): Tensor of shape [], [batch_size] or [batch_size, d1, ... dK], where K < N.
+ per_batch (bool): Whether to return the number of elements per batch or as a sum total.
+
+ Returns:
+ Tensor: The number of present (non-zero) elements in the losses tensor.
+ """
+ if isinstance(weights, float):
+ if weights != 0.0:
+ return losses.numel()
+ else:
+ return torch.tensor(0)
+
+ weights = torch.as_tensor(weights, dtype=torch.float32)
+
+ # Check for non-zero weights and broadcast to match losses
+ present = torch.where(
+ weights == 0.0, torch.zeros_like(weights), torch.ones_like(weights)
+ )
+ present = present.expand_as(losses)
+
+ # Reduce sum across the desired dimensions
+ if per_batch:
+ reduction_dims = tuple(range(1, present.dim()))
+ return torch.sum(present, dim=reduction_dims, keepdim=True)
+ else:
+ return torch.sum(present)
diff --git a/deeplabcut/pose_estimation_pytorch/models/criterions/weighted.py b/deeplabcut/pose_estimation_pytorch/models/criterions/weighted.py
new file mode 100644
index 0000000000..65d5f8f425
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/criterions/weighted.py
@@ -0,0 +1,125 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import utils
+from deeplabcut.pose_estimation_pytorch.models.criterions.base import (
+ BaseCriterion,
+ CRITERIONS,
+)
+
+
+class WeightedCriterion(BaseCriterion):
+ """Base class for weighted criterions"""
+
+ def __init__(self, criterion: nn.Module):
+ super().__init__()
+ self.criterion = criterion
+
+ def forward(
+ self,
+ output: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor | float = 1.0,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ output: predicted tensor
+ target: target tensor
+ weights: weights for each element in the loss calculation. If a float,
+ weights all elements by that value. Defaults to 1.
+
+ Returns:
+ the weighted loss
+ """
+ # shape of loss: (batch_size, n_kpts, heatmap_size, heatmap_size)
+ loss = self.criterion(output, target)
+ n_elems = utils.count_nonzero_elems(loss, weights)
+ if n_elems == 0:
+ n_elems = 1
+
+ return torch.sum(loss * weights) / n_elems
+
+
+@CRITERIONS.register_module
+class WeightedMSECriterion(WeightedCriterion):
+ """
+ Weighted Mean Squared Error (MSE) Loss.
+
+ This loss computes the Mean Squared Error between the prediction and target tensors,
+ but it also incorporates weights to adjust the contribution of each element in the loss
+ calculation. The loss is computed element-wise, and elements with a weight of 0 (masked items)
+ are excluded from the loss calculation.
+ """
+
+ def __init__(self) -> None:
+ super().__init__(nn.MSELoss(reduction="none"))
+
+ def forward(
+ self,
+ output: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor | float = 1.0,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Args:
+ output: predicted tensor
+ target: target tensor
+ weights: weights for each element in the loss calculation. If a float,
+ weights all elements by that value. Defaults to 1.
+
+ Returns:
+ the weighted loss
+ """
+ # shape of loss: (batch_size, n_kpts, h, w)
+ loss = self.criterion(output, target)
+ n_elems = utils.count_nonzero_elems(loss, weights)
+ if n_elems == 0:
+ n_elems = 1
+
+ return torch.sum(loss * weights) / n_elems
+
+
+@CRITERIONS.register_module
+class WeightedHuberCriterion(WeightedCriterion):
+ """
+ Weighted Huber Loss.
+
+ This loss computes the Huber loss between the prediction and target tensors,
+ but it also incorporates weights to adjust the contribution of each element in the loss
+ calculation. The loss is computed element-wise, and elements with a weight of 0 are
+ excluded from the loss calculation.
+ """
+
+ def __init__(self) -> None:
+ super().__init__(nn.HuberLoss(reduction="none"))
+
+
+@CRITERIONS.register_module
+class WeightedBCECriterion(WeightedCriterion):
+ """
+ Weighted Binary Cross Entropy (BCE) Loss.
+
+ This loss computes the Binary Cross Entropy loss between the prediction and target tensors,
+ but it also incorporates weights to adjust the contribution of each element in the loss
+ calculation. The loss is computed element-wise, and elements with a weight of 0 are
+ excluded from the loss calculation.
+ """
+
+ def __init__(self) -> None:
+ super().__init__(nn.BCEWithLogitsLoss(reduction="none"))
diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/__init__.py b/deeplabcut/pose_estimation_pytorch/models/detectors/__init__.py
new file mode 100644
index 0000000000..27f50f345a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/detectors/__init__.py
@@ -0,0 +1,16 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.detectors.base import (
+ DETECTORS,
+ BaseDetector,
+)
+from deeplabcut.pose_estimation_pytorch.models.detectors.fasterRCNN import FasterRCNN
+from deeplabcut.pose_estimation_pytorch.models.detectors.ssd import SSDLite
diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/base.py b/deeplabcut/pose_estimation_pytorch/models/detectors/base.py
new file mode 100644
index 0000000000..198c14ed0b
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/detectors/base.py
@@ -0,0 +1,128 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+
+def _build_detector(
+ cfg: dict,
+ weight_init: WeightInitialization | None = None,
+ pretrained: bool = False,
+ **kwargs,
+) -> BaseDetector:
+ """Builds a detector using its configuration file
+
+ Args:
+ cfg: The detector configuration.
+ weight_init: The weight initialization to use.
+ pretrained: Whether COCO pretrained weights should be loaded for the detector
+ **kwargs: Other parameters given by the Registry.
+
+ Returns:
+ the built detector
+ """
+ cfg["pretrained"] = pretrained
+ detector: BaseDetector = build_from_cfg(cfg, **kwargs)
+
+ if weight_init is not None and weight_init.detector_snapshot_path is not None:
+ logging.info(
+ f"Loading detector checkpoint from {weight_init.detector_snapshot_path}"
+ )
+ snapshot = torch.load(weight_init.detector_snapshot_path, map_location="cpu")
+ detector.load_state_dict(snapshot["model"])
+
+ return detector
+
+
+DETECTORS = Registry("detectors", build_func=_build_detector)
+
+
+class BaseDetector(ABC, nn.Module):
+ """
+ Definition of the class BaseDetector object.
+ This is an abstract class defining the common structure and inference for detectors.
+ """
+
+ def __init__(
+ self,
+ freeze_bn_stats: bool = False,
+ freeze_bn_weights: bool = False,
+ pretrained: bool = False,
+ ) -> None:
+ super().__init__()
+ self.freeze_bn_stats = freeze_bn_stats
+ self.freeze_bn_weights = freeze_bn_weights
+ self._pretrained = pretrained
+
+ @abstractmethod
+ def forward(
+ self, x: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None
+ ) -> tuple[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]:
+ """
+ Forward pass of the detector
+
+ Args:
+ x: images to be processed
+ targets: ground-truth boxes present in each images
+
+ Returns:
+ losses: {'loss_name': loss_value}
+ detections: for each of the b images, {"boxes": bounding_boxes}
+ """
+ pass
+
+ @abstractmethod
+ def get_target(self, labels: dict) -> list[dict]:
+ """
+ Get the target for training the detector
+
+ Args:
+ labels: annotations containing keypoints, bounding boxes, etc.
+
+ Returns:
+ list of dictionaries, each representing target information for a single annotation.
+ """
+ pass
+
+ def freeze_batch_norm_layers(self) -> None:
+ """Freezes batch norm layers
+
+ Running mean + var are always given to F.batch_norm, except when the layer is
+ in `train` mode and track_running_stats is False, see
+ https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html
+ So to 'freeze' the running stats, the only way is to set the layer to "eval"
+ mode.
+ """
+ for module in self.modules():
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
+ if self.freeze_bn_weights:
+ module.weight.requires_grad = False
+ module.bias.requires_grad = False
+ if self.freeze_bn_stats:
+ module.eval()
+
+ def train(self, mode: bool = True) -> None:
+ """Sets the module in training or evaluation mode.
+
+ Args:
+ mode: whether to set training mode (True) or evaluation mode (False)
+ """
+ super().train(mode)
+ if self.freeze_bn_weights or self.freeze_bn_stats:
+ self.freeze_batch_norm_layers()
diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/fasterRCNN.py b/deeplabcut/pose_estimation_pytorch/models/detectors/fasterRCNN.py
new file mode 100644
index 0000000000..edfdbe8a23
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/detectors/fasterRCNN.py
@@ -0,0 +1,74 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torchvision.models.detection as detection
+
+from deeplabcut.pose_estimation_pytorch.models.detectors.base import DETECTORS
+from deeplabcut.pose_estimation_pytorch.models.detectors.torchvision import (
+ TorchvisionDetectorAdaptor,
+)
+
+
+@DETECTORS.register_module
+class FasterRCNN(TorchvisionDetectorAdaptor):
+ """A FasterRCNN detector
+
+ Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks
+ Ren, Shaoqing, Kaiming He, Ross Girshick, and Jian Sun. "Faster r-cnn: Towards
+ real-time object detection with region proposal networks." Advances in neural
+ information processing systems 28 (2015).
+
+ This class is a wrapper of the torchvision implementation of a FasterRCNN (source:
+ https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py).
+
+ Some of the available FasterRCNN variants (from fastest to most powerful):
+ - fasterrcnn_mobilenet_v3_large_fpn
+ - fasterrcnn_resnet50_fpn
+ - fasterrcnn_resnet50_fpn_v2
+
+ Args:
+ variant: The FasterRCNN variant to use (see all options at
+ https://pytorch.org/vision/stable/models.html#object-detection).
+ pretrained: Whether to load model weights pretrained on COCO
+ box_score_thresh: during inference, only return proposals with a classification
+ score greater than box_score_thresh
+ """
+
+ def __init__(
+ self,
+ freeze_bn_stats: bool = False,
+ freeze_bn_weights: bool = False,
+ variant: str = "fasterrcnn_mobilenet_v3_large_fpn",
+ pretrained: bool = False,
+ box_score_thresh: float = 0.01,
+ ) -> None:
+ if not variant.lower().startswith("fasterrcnn"):
+ raise ValueError(
+ "The version must start with `fasterrcnn`. See available models at "
+ "https://pytorch.org/vision/stable/models.html#object-detection"
+ )
+
+ super().__init__(
+ model=variant,
+ weights=("COCO_V1" if pretrained else None),
+ num_classes=None,
+ freeze_bn_stats=freeze_bn_stats,
+ freeze_bn_weights=freeze_bn_weights,
+ box_score_thresh=box_score_thresh,
+ )
+
+ # Modify the base predictor to output the correct number of classes
+ num_classes = 2
+ in_features = self.model.roi_heads.box_predictor.cls_score.in_features
+ self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
+ in_features, num_classes
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/ssd.py b/deeplabcut/pose_estimation_pytorch/models/detectors/ssd.py
new file mode 100644
index 0000000000..3c8a254b71
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/detectors/ssd.py
@@ -0,0 +1,70 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torchvision.models.detection as detection
+
+from deeplabcut.pose_estimation_pytorch.models.detectors.base import DETECTORS
+from deeplabcut.pose_estimation_pytorch.models.detectors.torchvision import (
+ TorchvisionDetectorAdaptor,
+)
+
+
+@DETECTORS.register_module
+class SSDLite(TorchvisionDetectorAdaptor):
+ """An SSD object detection model"""
+
+ def __init__(
+ self,
+ freeze_bn_stats: bool = False,
+ freeze_bn_weights: bool = False,
+ pretrained: bool = False,
+ pretrained_from_imagenet: bool = False,
+ box_score_thresh: float = 0.01,
+ ) -> None:
+ model_kwargs = dict(weights_backbone=None)
+ if pretrained_from_imagenet:
+ model_kwargs["weights_backbone"] = "IMAGENET1K_V2"
+
+ super().__init__(
+ model="ssdlite320_mobilenet_v3_large",
+ weights=None,
+ num_classes=2,
+ freeze_bn_stats=freeze_bn_stats,
+ freeze_bn_weights=freeze_bn_weights,
+ box_score_thresh=box_score_thresh,
+ model_kwargs=model_kwargs,
+ )
+
+ if pretrained and not pretrained_from_imagenet:
+ weights = detection.SSDLite320_MobileNet_V3_Large_Weights.verify("COCO_V1")
+ state_dict = weights.get_state_dict(progress=False, check_hash=True)
+ for k, v in state_dict.items():
+ key_parts = k.split(".")
+ if (
+ len(key_parts) == 6
+ and key_parts[0] == "head"
+ and key_parts[1] == "classification_head"
+ and key_parts[2] == "module_list"
+ and key_parts[4] == "1"
+ and key_parts[5] in ("weight", "bias")
+ ):
+ # number of COCO classes: 90 + background (91)
+ # number of DLC classes: 1 + background (2)
+ # -> only keep weights for the background + first class
+
+ # future improvement: find best-suited class for the project
+ # and use those weights, instead of naively taking the first
+ all_classes_size = v.shape[0]
+ two_classes_size = 2 * (all_classes_size // 91)
+ state_dict[k] = v[:two_classes_size]
+
+ self.model.load_state_dict(state_dict)
diff --git a/deeplabcut/pose_estimation_pytorch/models/detectors/torchvision.py b/deeplabcut/pose_estimation_pytorch/models/detectors/torchvision.py
new file mode 100644
index 0000000000..6c700377f7
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/detectors/torchvision.py
@@ -0,0 +1,162 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Module to adapt torchvision detectors for DeepLabCut"""
+from __future__ import annotations
+
+import torch
+import torchvision.models.detection as detection
+
+from deeplabcut.pose_estimation_pytorch.models.detectors.base import (
+ BaseDetector,
+)
+
+
+class TorchvisionDetectorAdaptor(BaseDetector):
+ """An adaptor for torchvision detectors
+
+ This class is an adaptor for torchvision detectors to DeepLabCut detectors. Some of
+ the models (from fastest to most powerful) available are:
+ - ssdlite320_mobilenet_v3_large
+ - fasterrcnn_mobilenet_v3_large_fpn
+ - fasterrcnn_resnet50_fpn_v2
+
+ This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or
+ SSDLite) should be used instead.
+
+ The torchvision implementation does not allow to get both predictions and losses
+ with a single forward pass. Therefore, during evaluation only bounding box metrics
+ (mAP, mAR) are available for the test set. See validation loss issue:
+ - https://discuss.pytorch.org/t/compute-validation-loss-for-faster-rcnn/62333/12
+ - https://stackoverflow.com/a/65347721
+
+ Args:
+ model: The torchvision model to use (see all options at
+ https://pytorch.org/vision/stable/models.html#object-detection).
+ weights: The weights to load for the model. If None, no pre-trained weights are
+ loaded.
+ num_classes: Number of classes that the model should output. If None, the number
+ of classes the model is pre-trained on is used.
+ freeze_bn_stats: Whether to freeze stats for BatchNorm layers.
+ freeze_bn_weights: Whether to freeze weights for BatchNorm layers.
+ box_score_thresh: during inference, only return proposals with a classification
+ score greater than box_score_thresh
+ """
+
+ def __init__(
+ self,
+ model: str,
+ weights: str | None = None,
+ num_classes: int | None = 2,
+ freeze_bn_stats: bool = False,
+ freeze_bn_weights: bool = False,
+ box_score_thresh: float = 0.01,
+ model_kwargs: dict | None = None,
+ ) -> None:
+ super().__init__(
+ freeze_bn_stats=freeze_bn_stats,
+ freeze_bn_weights=freeze_bn_weights,
+ pretrained=weights is not None,
+ )
+
+ # Load the model
+ model_fn = getattr(detection, model)
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ self.model = model_fn(
+ weights=weights,
+ box_score_thresh=box_score_thresh,
+ num_classes=num_classes,
+ **model_kwargs,
+ )
+
+ # See source: https://stackoverflow.com/a/65347721
+ self.model.eager_outputs = lambda losses, detections: (losses, detections)
+
+ def forward(
+ self, x: torch.Tensor, targets: list[dict[str, torch.Tensor]] | None = None
+ ) -> tuple[dict[str, torch.Tensor], list[dict[str, torch.Tensor]]]:
+ """
+ Forward pass of the torchvision detector
+
+ Args:
+ x: images to be processed, of shape (b, c, h, w)
+ targets: ground-truth boxes present in the images
+
+ Returns:
+ losses: {'loss_name': loss_value}
+ detections: for each of the b images, {"boxes": bounding_boxes}
+ """
+ return self.model(x, targets)
+
+ def get_target(self, labels: dict) -> list[dict[str, torch.Tensor]]:
+ """
+ Returns target in a format a torchvision detector can handle
+
+ Args:
+ labels: dict of annotations, must contain the keys:
+ area: tensor containing area information for each annotation
+ labels: tensor containing class labels for each annotation
+ is_crowd: tensor indicating if each annotation is a crowd (1) or not (0)
+ image_id: tensor containing image ids for each annotation
+ boxes: tensor containing bounding box information for each annotation
+
+ Returns:
+ res: list of dictionaries, each representing target information for a single
+ annotation. Each dictionary contains the following keys:
+ 'area'
+ 'labels'
+ 'is_crowd'
+ 'boxes'
+
+ Examples:
+ input:
+ annotations = {
+ "area": torch.Tensor([100, 200]),
+ "labels": torch.Tensor([1, 2]),
+ "is_crowd": torch.Tensor([0, 1]),
+ "boxes": torch.Tensor([[10, 20, 30, 40], [50, 60, 70, 80]])
+ }
+ output:
+ res = [
+ {
+ 'area': tensor([100.]),
+ 'labels': tensor([1]),
+ 'image_id': tensor([1]),
+ 'is_crowd': tensor([0]),
+ 'boxes': tensor([[10., 20., 40., 60.]])
+ },
+ {
+ 'area': tensor([200.]),
+ 'labels': tensor([2]),
+ 'image_id': tensor([1]),
+ 'is_crowd': tensor([1]),
+ 'boxes': tensor([[50., 60., 70., 80.]])
+ }
+ ]
+ """
+ res = []
+ for i, box_ann in enumerate(labels["boxes"]):
+ mask = (box_ann[:, 2] > 0.0) & (box_ann[:, 3] > 0.0)
+ box_ann = box_ann[mask]
+ # bbox format conversion (x, y, w, h) -> (x1, y1, x2, y2)
+ box_ann[:, 2] += box_ann[:, 0]
+ box_ann[:, 3] += box_ann[:, 1]
+ res.append(
+ {
+ "area": labels["area"][i][mask],
+ "labels": labels["labels"][i][mask].long(),
+ "is_crowd": labels["is_crowd"][i][mask].long(),
+ "boxes": box_ann,
+ }
+ )
+
+ return res
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/__init__.py b/deeplabcut/pose_estimation_pytorch/models/heads/__init__.py
new file mode 100644
index 0000000000..4a65c8f84d
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/__init__.py
@@ -0,0 +1,16 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.heads.base import HEADS, BaseHead
+from deeplabcut.pose_estimation_pytorch.models.heads.dekr import DEKRHead
+from deeplabcut.pose_estimation_pytorch.models.heads.dlcrnet import DLCRNetHead
+from deeplabcut.pose_estimation_pytorch.models.heads.rtmcc_head import RTMCCHead
+from deeplabcut.pose_estimation_pytorch.models.heads.simple_head import HeatmapHead
+from deeplabcut.pose_estimation_pytorch.models.heads.transformer import TransformerHead
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/base.py b/deeplabcut/pose_estimation_pytorch/models/heads/base.py
new file mode 100644
index 0000000000..b0d0a8c49f
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/base.py
@@ -0,0 +1,186 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+from deeplabcut.pose_estimation_pytorch.models.weight_init import (
+ BaseWeightInitializer,
+ WEIGHT_INIT,
+)
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+HEADS = Registry("heads", build_func=build_from_cfg)
+
+
+class BaseHead(ABC, nn.Module):
+ """A head for pose estimation models
+
+ Attributes:
+ stride: The stride for the head (or neck + head pair), where positive values
+ indicate an increase in resolution while negative values a decrease.
+ Assuming that H and W are divisible by `stride`, this is the value such
+ that if a backbone outputs an encoding of shape (C, H, W), the head will
+ output heatmaps of shape:
+ (C, H * stride, W * stride) if stride > 0
+ (C, -H/stride, -W/stride) if stride < 0
+ predictor: an object to generate predictions from the head outputs
+ target_generator: a target generator which must output a target for each
+ output key of this module (i.e. if forward returns a "heatmap" tensor and
+ an "offset" tensor, then targets must be generated for both)
+ criterion: either a single criterion (e.g. if this head only outputs heatmaps)
+ or a dictionary mapping the outputs of this head to the criterion to use
+ (e.g. a "heatmap" criterion and an "offset" criterion for DEKR).
+ aggregator: if the criterion is a dictionary, cannot be none. used to combine
+ the individual losses from this head into one "total_loss"
+ """
+
+ def __init__(
+ self,
+ stride: int | float,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion] | BaseCriterion,
+ aggregator: BaseLossAggregator | None = None,
+ weight_init: str | dict | BaseWeightInitializer | None = None,
+ ) -> None:
+ super().__init__()
+ if stride == 0:
+ raise ValueError(f"Stride must not be 0. Found {stride}.")
+
+ self.stride = stride
+ self.predictor = predictor
+ self.target_generator = target_generator
+ self.criterion = criterion
+ self.aggregator = aggregator
+
+ self.weight_init: BaseWeightInitializer | None = None
+ if isinstance(weight_init, BaseWeightInitializer):
+ self.weight_init = weight_init
+ elif isinstance(weight_init, (str, dict)):
+ self.weight_init = WEIGHT_INIT.build(weight_init)
+ elif weight_init is not None:
+ raise ValueError(
+ f"Could not parse ``weight_init`` parameter: {weight_init}."
+ )
+
+ if isinstance(criterion, dict):
+ if aggregator is None:
+ raise ValueError(
+ f"When multiple criterions are defined, a loss aggregator must "
+ "also be given"
+ )
+ else:
+ if aggregator is not None:
+ raise ValueError(
+ f"Cannot use a loss aggregator with a single criterion"
+ )
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ """
+ Given the feature maps for an image ()
+
+ Args:
+ x: the feature maps, of shape (b, c, h, w)
+
+ Returns:
+ the head outputs (e.g. "heatmap", "locref")
+ """
+ pass
+
+ def get_loss(
+ self,
+ outputs: dict[str, torch.Tensor],
+ targets: dict[str, dict[str, torch.Tensor]],
+ ) -> dict[str, torch.Tensor]:
+ """
+ Computes the loss for this head
+
+ Args:
+ outputs: the outputs of this head
+ targets: the targets for this head
+
+ Returns:
+ A dictionary containing minimally "total_loss" key mapping to the total
+ loss for this head (from which backwards() should be called). Can contain
+ other keys containing losses that can be logged for informational purposes.
+ """
+ if self.aggregator is None:
+ assert len(outputs) == len(targets) == 1
+ key = [k for k in outputs.keys()][0]
+ return {"total_loss": self.criterion(outputs[key], **targets[key])}
+
+ losses = {
+ name: criterion(outputs[name], **targets[name])
+ for name, criterion in self.criterion.items()
+ }
+ losses["total_loss"] = self.aggregator(losses)
+ return losses
+
+ def _init_weights(self) -> None:
+ """Should be called once all modules for the class are created"""
+ if self.weight_init is not None:
+ self.weight_init.init_weights(self)
+
+
+class WeightConversionMixin(ABC):
+ """A mixin for heads that can re-order and/or filter the output channels.
+
+ This mixin is useful to convert SuperAnimal model weights such that they can be used
+ in downstream projects (either existing or new), where only a subset of keypoints
+ are available (and where they might be re-ordered).
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+
+ @staticmethod
+ @abstractmethod
+ def convert_weights(
+ state_dict: dict[str, torch.Tensor],
+ module_prefix: str,
+ conversion: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ """Converts pre-trained weights to be fine-tuned on another dataset
+
+ Args:
+ state_dict: the state dict for the pre-trained model
+ module_prefix: the prefix for weights in this head (e.g., 'heads.bodypart.')
+ conversion: the mapping of old indices to new indices
+
+ Examples:
+ A SuperAnimal model was trained on the keypoints ["ear_left", "ear_right",
+ "eye_left", "eye_right", "nose"]. A down-stream project has the bodyparts
+ labeled ["nose", "eye_left", "eye_right"]. The SuperAnimal weights can be
+ converted (to be used with the downstream project) with the following code:
+
+ ``
+ state_dict = torch.load(
+ snapshot_path, map_location=torch.device('cpu')
+ )["model"]
+ state_dict = HeadClass.convert_weights(
+ state_dict,
+ "heads.bodypart",
+ [4, 2, 3]
+ )
+ ``
+ """
+ pass
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/dekr.py b/deeplabcut/pose_estimation_pytorch/models/heads/dekr.py
new file mode 100644
index 0000000000..d61da6a4e9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/dekr.py
@@ -0,0 +1,429 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads.base import BaseHead, HEADS
+from deeplabcut.pose_estimation_pytorch.models.modules.conv_block import (
+ AdaptBlock,
+ BaseBlock,
+ BasicBlock,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+from deeplabcut.pose_estimation_pytorch.models.weight_init import BaseWeightInitializer
+
+
+@HEADS.register_module
+class DEKRHead(BaseHead):
+ """
+ DEKR head based on:
+ Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression
+ Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021
+ Code based on:
+ https://github.com/HRNet/DEKR
+ """
+
+ def __init__(
+ self,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion],
+ aggregator: BaseLossAggregator,
+ heatmap_config: dict,
+ offset_config: dict,
+ weight_init: str | dict | BaseWeightInitializer | None = "dekr",
+ stride: int | float = 1, # head stride - should always be 1 for DEKR
+ ) -> None:
+ super().__init__(
+ stride, predictor, target_generator, criterion, aggregator, weight_init
+ )
+ self.heatmap_head = DEKRHeatmap(**heatmap_config)
+ self.offset_head = DEKROffset(**offset_config)
+ self._init_weights()
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ return {"heatmap": self.heatmap_head(x), "offset": self.offset_head(x)}
+
+
+class DEKRHeatmap(nn.Module):
+ """
+ DEKR head to compute the heatmaps corresponding to keypoints based on:
+ Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression
+ Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021
+ Code based on:
+ https://github.com/HRNet/DEKR
+ """
+
+ def __init__(
+ self,
+ channels: tuple[int],
+ num_blocks: int,
+ dilation_rate: int,
+ final_conv_kernel: int,
+ block: type(BaseBlock) = BasicBlock,
+ ) -> None:
+ """Summary:
+ Constructor of the HeatmapDEKRHead.
+ Loads the data.
+
+ Args:
+ channels: tuple containing the number of channels for the head.
+ num_blocks: number of blocks in the head
+ dilation_rate: dilation rate for the head
+ final_conv_kernel: kernel size for the final convolution
+ block: type of block to use in the head. Defaults to BasicBlock.
+
+ Returns:
+ None
+
+ Examples:
+ channels = (64,128,17)
+ num_blocks = 3
+ dilation_rate = 2
+ final_conv_kernel = 3
+ block = BasicBlock
+ """
+ super().__init__()
+ self.bn_momentum = 0.1
+ self.inp_channels = channels[0]
+ self.num_joints_with_center = channels[
+ 2
+ ] # Should account for the center being a joint
+ self.final_conv_kernel = final_conv_kernel
+
+ self.transition_heatmap = self._make_transition_for_head(
+ self.inp_channels, channels[1]
+ )
+ self.head_heatmap = self._make_heatmap_head(
+ block, num_blocks, channels[1], dilation_rate
+ )
+
+ def _make_transition_for_head(
+ self, in_channels: int, out_channels: int
+ ) -> nn.Sequential:
+ """Summary:
+ Construct the transition layer for the head.
+
+ Args:
+ in_channels: number of input channels
+ out_channels: number of output channels
+
+ Returns:
+ Transition layer consisting of Conv2d, BatchNorm2d, and ReLU
+ """
+ transition_layer = [
+ nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ ]
+ return nn.Sequential(*transition_layer)
+
+ def _make_heatmap_head(
+ self,
+ block: type(BaseBlock),
+ num_blocks: int,
+ num_channels: int,
+ dilation_rate: int,
+ ) -> nn.ModuleList:
+ """Summary:
+ Construct the heatmap head
+
+ Args:
+ block: type of block to use in the head.
+ num_blocks: number of blocks in the head.
+ num_channels: number of input channels for the head.
+ dilation_rate: dilation rate for the head.
+
+ Returns:
+ List of modules representing the heatmap head layers.
+ """
+ heatmap_head_layers = []
+
+ feature_conv = self._make_layer(
+ block, num_channels, num_channels, num_blocks, dilation=dilation_rate
+ )
+ heatmap_head_layers.append(feature_conv)
+
+ heatmap_conv = nn.Conv2d(
+ in_channels=num_channels,
+ out_channels=self.num_joints_with_center,
+ kernel_size=self.final_conv_kernel,
+ stride=1,
+ padding=1 if self.final_conv_kernel == 3 else 0,
+ )
+ heatmap_head_layers.append(heatmap_conv)
+
+ return nn.ModuleList(heatmap_head_layers)
+
+ def _make_layer(
+ self,
+ block: type(BaseBlock),
+ in_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ stride: int = 1,
+ dilation: int = 1,
+ ) -> nn.Sequential:
+ """Summary:
+ Construct a layer in the head.
+
+ Args:
+ block: type of block to use in the head.
+ in_channels: number of input channels for the layer.
+ out_channels: number of output channels for the layer.
+ num_blocks: number of blocks in the layer.
+ stride: stride for the convolutional layer. Defaults to 1.
+ dilation: dilation rate for the convolutional layer. Defaults to 1.
+
+ Returns:
+ Sequential layer containing the specified num_blocks.
+ """
+ downsample = None
+ if stride != 1 or in_channels != out_channels * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ out_channels * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ out_channels * block.expansion, momentum=self.bn_momentum
+ ),
+ )
+
+ layers = [
+ block(in_channels, out_channels, stride, downsample, dilation=dilation)
+ ]
+ in_channels = out_channels * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(block(in_channels, out_channels, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ heatmap = self.head_heatmap[1](self.head_heatmap[0](self.transition_heatmap(x)))
+
+ return heatmap
+
+
+class DEKROffset(nn.Module):
+ """
+ DEKR module to compute the offset from the center corresponding to each keypoints:
+ Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression
+ Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021
+ Code based on:
+ https://github.com/HRNet/DEKR
+ """
+
+ def __init__(
+ self,
+ channels: tuple[int, ...],
+ num_offset_per_kpt: int,
+ num_blocks: int,
+ dilation_rate: int,
+ final_conv_kernel: int,
+ block: type(BaseBlock) = AdaptBlock,
+ ) -> None:
+ """Args:
+ channels: tuple containing the number of input, offset, and output channels.
+ num_offset_per_kpt: number of offset values per keypoint.
+ num_blocks: number of blocks in the head.
+ dilation_rate: dilation rate for convolutional layers.
+ final_conv_kernel: kernel size for the final convolution.
+ block: type of block to use in the head. Defaults to AdaptBlock.
+ """
+ super().__init__()
+ self.inp_channels = channels[0]
+ self.num_joints = channels[2]
+ self.num_joints_with_center = self.num_joints + 1
+
+ self.bn_momentum = 0.1
+ self.offset_perkpt = num_offset_per_kpt
+ self.num_joints_without_center = self.num_joints
+ self.offset_channels = self.offset_perkpt * self.num_joints_without_center
+ assert self.offset_channels == channels[1]
+
+ self.num_blocks = num_blocks
+ self.dilation_rate = dilation_rate
+ self.final_conv_kernel = final_conv_kernel
+
+ self.transition_offset = self._make_transition_for_head(
+ self.inp_channels, self.offset_channels
+ )
+ (
+ self.offset_feature_layers,
+ self.offset_final_layer,
+ ) = self._make_separete_regression_head(
+ block,
+ num_blocks=num_blocks,
+ num_channels_per_kpt=self.offset_perkpt,
+ dilation_rate=self.dilation_rate,
+ )
+
+ def _make_layer(
+ self,
+ block: type(BaseBlock),
+ in_channels: int,
+ out_channels: int,
+ num_blocks: int,
+ stride: int = 1,
+ dilation: int = 1,
+ ) -> nn.Sequential:
+ """Summary:
+ Create a sequential layer with the specified block and number of num_blocks.
+
+ Args:
+ block: block type to use in the layer.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ num_blocks: number of blocks to be stacked in the layer.
+ stride: stride for the first block. Defaults to 1.
+ dilation: dilation rate for the blocks. Defaults to 1.
+
+ Returns:
+ A sequential layer containing stacked num_blocks.
+
+ Examples:
+ input:
+ block=BasicBlock
+ in_channels=64
+ out_channels=128
+ num_blocks=3
+ stride=1
+ dilation=1
+ """
+ downsample = None
+ if stride != 1 or in_channels != out_channels * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ out_channels * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ out_channels * block.expansion, momentum=self.bn_momentum
+ ),
+ )
+
+ layers = []
+ layers.append(
+ block(in_channels, out_channels, stride, downsample, dilation=dilation)
+ )
+ in_channels = out_channels * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(block(in_channels, out_channels, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+ def _make_transition_for_head(
+ self, in_channels: int, out_channels: int
+ ) -> nn.Sequential:
+ """Summary:
+ Create a transition layer for the head.
+
+ Args:
+ in_channels: number of input channels
+ out_channels: number of output channels
+
+ Returns:
+ Sequential layer containing the transition operations.
+ """
+ transition_layer = [
+ nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ ]
+ return nn.Sequential(*transition_layer)
+
+ def _make_separete_regression_head(
+ self,
+ block: type(BaseBlock),
+ num_blocks: int,
+ num_channels_per_kpt: int,
+ dilation_rate: int,
+ ) -> tuple:
+ """Summary:
+
+ Args:
+ block: type of block to use in the head
+ num_blocks: number of blocks in the regression head
+ num_channels_per_kpt: number of channels per keypoint
+ dilation_rate: dilation rate for the regression head
+
+ Returns:
+ A tuple containing two ModuleList objects.
+ The first ModuleList contains the feature convolution layers for each keypoint,
+ and the second ModuleList contains the final offset convolution layers.
+ """
+ offset_feature_layers = []
+ offset_final_layer = []
+
+ for _ in range(self.num_joints):
+ feature_conv = self._make_layer(
+ block,
+ num_channels_per_kpt,
+ num_channels_per_kpt,
+ num_blocks,
+ dilation=dilation_rate,
+ )
+ offset_feature_layers.append(feature_conv)
+
+ offset_conv = nn.Conv2d(
+ in_channels=num_channels_per_kpt,
+ out_channels=2,
+ kernel_size=self.final_conv_kernel,
+ stride=1,
+ padding=1 if self.final_conv_kernel == 3 else 0,
+ )
+ offset_final_layer.append(offset_conv)
+
+ return nn.ModuleList(offset_feature_layers), nn.ModuleList(offset_final_layer)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Summary:
+ Perform forward pass through the OffsetDEKRHead.
+
+ Args:
+ x: input tensor to the head.
+
+ Returns:
+ offset: Computed offsets from the center corresponding to each keypoint.
+ The tensor will have the shape (N, num_joints * 2, H, W), where N is the batch size,
+ num_joints is the number of keypoints, and H and W are the height and width of the output tensor.
+ """
+ final_offset = []
+ offset_feature = self.transition_offset(x)
+
+ for j in range(self.num_joints):
+ final_offset.append(
+ self.offset_final_layer[j](
+ self.offset_feature_layers[j](
+ offset_feature[
+ :, j * self.offset_perkpt : (j + 1) * self.offset_perkpt
+ ]
+ )
+ )
+ )
+
+ offset = torch.cat(final_offset, dim=1)
+
+ return offset
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/dlcrnet.py b/deeplabcut/pose_estimation_pytorch/models/heads/dlcrnet.py
new file mode 100644
index 0000000000..6eeaf68df0
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/dlcrnet.py
@@ -0,0 +1,156 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads.base import HEADS
+from deeplabcut.pose_estimation_pytorch.models.heads.simple_head import (
+ DeconvModule,
+ HeatmapHead,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+from deeplabcut.pose_estimation_pytorch.models.weight_init import BaseWeightInitializer
+
+
+@HEADS.register_module
+class DLCRNetHead(HeatmapHead):
+ """A head for DLCRNet models using Part-Affinity Fields to predict individuals"""
+
+ def __init__(
+ self,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion],
+ aggregator: BaseLossAggregator,
+ heatmap_config: dict,
+ locref_config: dict,
+ paf_config: dict,
+ num_stages: int = 5,
+ features_dim: int = 128,
+ weight_init: str | dict | BaseWeightInitializer | None = None,
+ ) -> None:
+ self.num_stages = num_stages
+ # FIXME Cleaner __init__ to avoid initializing unused layers
+ in_channels = heatmap_config["channels"][0]
+ num_keypoints = heatmap_config["channels"][-1]
+ num_limbs = paf_config["channels"][-1] # Already has the 2x multiplier
+ in_refined_channels = features_dim + num_keypoints + num_limbs
+ if num_stages > 0:
+ heatmap_config["channels"][0] = paf_config["channels"][0] = (
+ in_refined_channels
+ )
+ locref_config["channels"][0] = locref_config["channels"][-1]
+
+ super().__init__(
+ predictor,
+ target_generator,
+ criterion,
+ aggregator,
+ heatmap_config,
+ locref_config,
+ weight_init,
+ )
+ if num_stages > 0:
+ self.stride *= 2 # extra deconv layer where it's multi-stage
+
+ self.paf_head = DeconvModule(**paf_config)
+
+ self.convt1 = self._make_layer_same_padding(
+ in_channels=in_channels, out_channels=num_keypoints
+ )
+ self.convt2 = self._make_layer_same_padding(
+ in_channels=in_channels, out_channels=locref_config["channels"][-1]
+ )
+ self.convt3 = self._make_layer_same_padding(
+ in_channels=in_channels, out_channels=num_limbs
+ )
+ self.convt4 = self._make_layer_same_padding(
+ in_channels=in_channels, out_channels=features_dim
+ )
+ self.hm_ref_layers = nn.ModuleList()
+ self.paf_ref_layers = nn.ModuleList()
+ for _ in range(num_stages):
+ self.hm_ref_layers.append(
+ self._make_refinement_layer(
+ in_channels=in_refined_channels, out_channels=num_keypoints
+ )
+ )
+ self.paf_ref_layers.append(
+ self._make_refinement_layer(
+ in_channels=in_refined_channels, out_channels=num_limbs
+ )
+ )
+ self._init_weights()
+
+ def _make_layer_same_padding(
+ self, in_channels: int, out_channels: int
+ ) -> nn.ConvTranspose2d:
+ # FIXME There is no consensual solution to emulate TF behavior in pytorch
+ # see https://github.com/pytorch/pytorch/issues/3867
+ return nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ )
+
+ def _make_refinement_layer(self, in_channels: int, out_channels: int) -> nn.Conv2d:
+ """Summary:
+ Helper function to create a refinement layer.
+
+ Args:
+ in_channels: number of input channels
+ out_channels: number of output channels
+
+ Returns:
+ refinement_layer: the refinement layer.
+ """
+ return nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding="same"
+ )
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ if self.num_stages > 0:
+ stage1_hm_out = self.convt1(x)
+ stage1_paf_out = self.convt3(x)
+ features = self.convt4(x)
+ stage2_in = torch.cat((stage1_hm_out, stage1_paf_out, features), dim=1)
+ stage_in = stage2_in
+ stage_paf_out = stage1_paf_out
+ stage_hm_out = stage1_hm_out
+ for i, (hm_ref_layer, paf_ref_layer) in enumerate(
+ zip(self.hm_ref_layers, self.paf_ref_layers)
+ ):
+ pre_stage_hm_out = stage_hm_out
+ stage_hm_out = hm_ref_layer(stage_in)
+ stage_paf_out = paf_ref_layer(stage_in)
+ if i > 0:
+ stage_hm_out += pre_stage_hm_out
+ stage_in = torch.cat((stage_hm_out, stage_paf_out, features), dim=1)
+ return {
+ "heatmap": self.heatmap_head(stage_in),
+ "locref": self.locref_head(self.convt2(x)),
+ "paf": self.paf_head(stage_in),
+ }
+ return {
+ "heatmap": self.heatmap_head(x),
+ "locref": self.locref_head(x),
+ "paf": self.paf_head(x),
+ }
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/rtmcc_head.py b/deeplabcut/pose_estimation_pytorch/models/heads/rtmcc_head.py
new file mode 100644
index 0000000000..6b99ec7308
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/rtmcc_head.py
@@ -0,0 +1,162 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Modified SimCC head for the RTMPose model
+
+Based on the official ``mmpose`` RTMCC head implementation. For more information, see
+.
+"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads.base import (
+ BaseHead,
+ HEADS,
+)
+from deeplabcut.pose_estimation_pytorch.models.modules import (
+ GatedAttentionUnit,
+ ScaleNorm,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+from deeplabcut.pose_estimation_pytorch.models.weight_init import BaseWeightInitializer
+
+
+@HEADS.register_module
+class RTMCCHead(BaseHead):
+ """RTMPose Coordinate Classification head
+
+ The RTMCC head is itself adapted from the SimCC head. For more information, see
+ "SimCC: a Simple Coordinate Classification Perspective for Human Pose Estimation"
+ () and "RTMPose: Real-Time Multi-Person Pose
+ Estimation based on MMPose" ().
+
+ Args:
+ input_size: The size of images given to the pose estimation model.
+ in_channels: The number of input channels for the head.
+ out_channels: Number of channels output by the head (number of bodyparts).
+ in_featuremap_size: The size of the input feature map for the head. This is
+ equal to the input_size divided by the backbone stride.
+ simcc_split_ratio: The split ratio of pixels, as described in SimCC.
+ final_layer_kernel_size: Kernel size of the final convolutional layer.
+ gau_cfg: Configuration for the GatedAttentionUnit.
+ predictor: The predictor for the head. Should usually be a `SimCCPredictor`.
+ target_generator: The target generator for the head. Should usually be a
+ `SimCCGenerator`.
+ criterion: The loss criterions for the RTMCC outputs. There should be a
+ criterion for "x" and a criterion for "y".
+ aggregator: The loss aggregator to combine the losses.
+ weight_init: The weight initializer to use for the head.
+ """
+
+ def __init__(
+ self,
+ input_size: tuple[int, int],
+ in_channels: int,
+ out_channels: int,
+ in_featuremap_size: tuple[int, int],
+ simcc_split_ratio: float,
+ final_layer_kernel_size: int,
+ gau_cfg: dict,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion],
+ aggregator: BaseLossAggregator,
+ weight_init: str | dict | BaseWeightInitializer | None = None,
+ ) -> None:
+ super().__init__(
+ 1,
+ predictor,
+ target_generator,
+ criterion,
+ aggregator,
+ weight_init,
+ )
+
+ self.input_size = input_size
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.in_featuremap_size = in_featuremap_size
+ self.simcc_split_ratio = simcc_split_ratio
+
+ flatten_dims = self.in_featuremap_size[0] * self.in_featuremap_size[1]
+ out_w = int(self.input_size[0] * self.simcc_split_ratio)
+ out_h = int(self.input_size[1] * self.simcc_split_ratio)
+
+ self.gau = GatedAttentionUnit(
+ num_token=self.out_channels,
+ in_token_dims=gau_cfg["hidden_dims"],
+ out_token_dims=gau_cfg["hidden_dims"],
+ expansion_factor=gau_cfg["expansion_factor"],
+ s=gau_cfg["s"],
+ eps=1e-5,
+ dropout_rate=gau_cfg["dropout_rate"],
+ drop_path=gau_cfg["drop_path"],
+ attn_type="self-attn",
+ act_fn=gau_cfg["act_fn"],
+ use_rel_bias=gau_cfg["use_rel_bias"],
+ pos_enc=gau_cfg["pos_enc"],
+ )
+
+ self.final_layer = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=final_layer_kernel_size,
+ stride=1,
+ padding=final_layer_kernel_size // 2,
+ )
+ self.mlp = nn.Sequential(
+ ScaleNorm(flatten_dims),
+ nn.Linear(flatten_dims, gau_cfg["hidden_dims"], bias=False),
+ )
+
+ self.cls_x = nn.Linear(gau_cfg["hidden_dims"], out_w, bias=False)
+ self.cls_y = nn.Linear(gau_cfg["hidden_dims"], out_h, bias=False)
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ feats = self.final_layer(x) # -> B, K, H, W
+ feats = torch.flatten(feats, start_dim=2) # -> B, K, hidden=HxW
+ feats = self.mlp(feats) # -> B, K, hidden
+ feats = self.gau(feats)
+ x, y = self.cls_x(feats), self.cls_y(feats)
+ return dict(x=x, y=y)
+
+ @staticmethod
+ def update_input_size(model_cfg: dict, input_size: tuple[int, int]) -> None:
+ """Updates an RTMPose model configuration file for a new image input size
+
+ Args:
+ model_cfg: The model configuration to update in-place.
+ input_size: The updated input (width, height).
+ """
+ _sigmas = {192: 4.9, 256: 5.66, 288: 6, 384: 6.93}
+
+ def _sigma(size: int) -> float:
+ sigma = _sigmas.get(size)
+ if sigma is None:
+ return 2.87 + 0.01 * size
+
+ return sigma
+
+ w, h = input_size
+ model_cfg["data"]["inference"]["top_down_crop"] = dict(width=w, height=h)
+ model_cfg["data"]["train"]["top_down_crop"] = dict(width=w, height=h)
+ head_cfg = model_cfg["model"]["heads"]["bodypart"]
+ head_cfg["input_size"] = input_size
+ head_cfg["in_featuremap_size"] = h // 32, w // 32
+ head_cfg["target_generator"]["input_size"] = input_size
+ head_cfg["target_generator"]["sigma"] = (_sigma(w), _sigma(h))
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/simple_head.py b/deeplabcut/pose_estimation_pytorch/models/heads/simple_head.py
new file mode 100644
index 0000000000..334e674237
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/simple_head.py
@@ -0,0 +1,258 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ BaseCriterion,
+ BaseLossAggregator,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads.base import (
+ BaseHead,
+ HEADS,
+ WeightConversionMixin,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+from deeplabcut.pose_estimation_pytorch.models.weight_init import BaseWeightInitializer
+
+
+@HEADS.register_module
+class HeatmapHead(WeightConversionMixin, BaseHead):
+ """Deconvolutional head to predict maps from the extracted features.
+
+ This class implements a simple deconvolutional head to predict maps from the
+ extracted features.
+
+ Args:
+ predictor: The predictor used to transform heatmaps into keypoints.
+ target_generator: The module to generate target heatmaps from keypoints.
+ criterion: The loss criterion(s) for the head.
+ aggregator: The loss aggregator to use, if multiple criterions are used.
+ heatmap_config: The configuration for the heatmap outputs of the head.
+ locref_config: The configuration for the location refinement outputs (None if
+ no location refinement should be used).
+ weight_init: The way to initialize weights for the head. If None, default
+ PyTorch initialization is used. Otherwise, a BaseWeightInitializer can be
+ given (or a configuration for a BaseWeightInitializer). To initialize
+ the weights with a normal distribution, you could pass
+ ``weight_init="normal"`` (which initializes weights using a Normal
+ distribution 0.001 and biases with 0), or you could pass ``weight_init={
+ type="normal", std=0.01}`` to change the standard deviation used. All
+ BaseWeightInitializers are defined in deeplabcut/pose_estimation_pytorch/
+ models/weight_init.py.
+ """
+
+ def __init__(
+ self,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: dict[str, BaseCriterion] | BaseCriterion,
+ aggregator: BaseLossAggregator | None,
+ heatmap_config: dict,
+ locref_config: dict | None = None,
+ weight_init: str | dict | BaseWeightInitializer | None = None,
+ ) -> None:
+ heatmap_head = DeconvModule(**heatmap_config)
+ locref_head = None
+ if locref_config is not None:
+ locref_head = DeconvModule(**locref_config)
+
+ # check that the heatmap and locref modules have the same stride
+ if heatmap_head.stride != locref_head.stride:
+ raise ValueError(
+ f"Invalid model config: Your heatmap and locref need to have the "
+ f"same stride (found {heatmap_head.stride}, "
+ f"{locref_head.stride}). Please check your config (found "
+ f"heatmap_config={heatmap_config}, locref_config={locref_config}"
+ )
+
+ super().__init__(
+ heatmap_head.stride,
+ predictor,
+ target_generator,
+ criterion,
+ aggregator,
+ weight_init,
+ )
+ self.heatmap_head = heatmap_head
+ self.locref_head = locref_head
+ self._init_weights()
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ outputs = {"heatmap": self.heatmap_head(x)}
+ if self.locref_head is not None:
+ outputs["locref"] = self.locref_head(x)
+ return outputs
+
+ @staticmethod
+ def convert_weights(
+ state_dict: dict[str, torch.Tensor],
+ module_prefix: str,
+ conversion: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ """Converts pre-trained weights to be fine-tuned on another dataset
+
+ Args:
+ state_dict: the state dict for the pre-trained model
+ module_prefix: the prefix for weights in this head (e.g., 'heads.bodypart.')
+ conversion: the mapping of old indices to new indices
+ """
+ state_dict = DeconvModule.convert_weights(
+ state_dict,
+ f"{module_prefix}heatmap_head.",
+ conversion,
+ )
+
+ locref_conversion = torch.stack(
+ [2 * conversion, 2 * conversion + 1],
+ dim=1,
+ ).reshape(-1)
+ state_dict = DeconvModule.convert_weights(
+ state_dict,
+ f"{module_prefix}locref_head.",
+ locref_conversion,
+ )
+ return state_dict
+
+
+class DeconvModule(nn.Module):
+ """
+ Deconvolutional module to predict maps from the extracted features.
+ """
+
+ def __init__(
+ self,
+ channels: list[int],
+ kernel_size: list[int],
+ strides: list[int],
+ final_conv: dict | None = None,
+ ) -> None:
+ """
+ Args:
+ channels: List containing the number of input and output channels for each
+ deconvolutional layer.
+ kernel_size: List containing the kernel size for each deconvolutional layer.
+ strides: List containing the stride for each deconvolutional layer.
+ final_conv: Configuration for a conv layer after the deconvolutional layers,
+ if one should be added. Must have keys "out_channels" and "kernel_size".
+ """
+ super().__init__()
+ if not (len(channels) == len(kernel_size) + 1 == len(strides) + 1):
+ raise ValueError(
+ "Incorrect DeconvModule configuration: there should be one more number"
+ f" of channels than kernel_sizes and strides, found {len(channels)} "
+ f"channels, {len(kernel_size)} kernels and {len(strides)} strides."
+ )
+
+ in_channels = channels[0]
+ head_stride = 1
+ self.deconv_layers = nn.Identity()
+ if len(kernel_size) > 0:
+ self.deconv_layers = nn.Sequential(
+ *self._make_layers(in_channels, channels[1:], kernel_size, strides)
+ )
+ for s in strides:
+ head_stride *= s
+
+ self.stride = head_stride
+ self.final_conv = nn.Identity()
+ if final_conv:
+ self.final_conv = nn.Conv2d(
+ in_channels=channels[-1],
+ out_channels=final_conv["out_channels"],
+ kernel_size=final_conv["kernel_size"],
+ stride=1,
+ )
+
+ @staticmethod
+ def _make_layers(
+ in_channels: int,
+ out_channels: list[int],
+ kernel_sizes: list[int],
+ strides: list[int],
+ ) -> list[nn.Module]:
+ """
+ Helper function to create the deconvolutional layers.
+
+ Args:
+ in_channels: number of input channels to the module
+ out_channels: number of output channels of each layer
+ kernel_sizes: size of the deconvolutional kernel
+ strides: stride for the convolution operation
+
+ Returns:
+ the deconvolutional layers
+ """
+ layers = []
+ for out_channels, k, s in zip(out_channels, kernel_sizes, strides):
+ layers.append(
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=k, stride=s)
+ )
+ layers.append(nn.ReLU())
+ in_channels = out_channels
+ return layers[:-1]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass of the HeatmapHead
+
+ Args:
+ x: input tensor
+
+ Returns:
+ out: output tensor
+ """
+ x = self.deconv_layers(x)
+ x = self.final_conv(x)
+ return x
+
+ @staticmethod
+ def convert_weights(
+ state_dict: dict[str, torch.Tensor],
+ module_prefix: str,
+ conversion: torch.Tensor,
+ ) -> dict[str, torch.Tensor]:
+ """Converts pre-trained weights to be fine-tuned on another dataset
+
+ Args:
+ state_dict: the state dict for the pre-trained model
+ module_prefix: the prefix for weights in this head (e.g., 'heads.bodypart')
+ conversion: the mapping of old indices to new indices
+ """
+ if f"{module_prefix}final_conv.weight" in state_dict:
+ # has final convolution
+ weight_key = f"{module_prefix}final_conv.weight"
+ bias_key = f"{module_prefix}final_conv.bias"
+ state_dict[weight_key] = state_dict[weight_key][conversion]
+ state_dict[bias_key] = state_dict[bias_key][conversion]
+ return state_dict
+
+ # get the last deconv layer of the net
+ next_index = 0
+ while f"{module_prefix}deconv_layers.{next_index}.weight" in state_dict:
+ next_index += 1
+ last_index = next_index - 1
+
+ # if there are deconv layers for this module prefix (there might not be,
+ # e.g., when there are no location refinement layers in a heatmap head)
+ if last_index >= 0:
+ weight_key = f"{module_prefix}deconv_layers.{last_index}.weight"
+ bias_key = f"{module_prefix}deconv_layers.{last_index}.bias"
+
+ # for ConvTranspose2d, the weight shape is (in_channels, out_channels, ...)
+ # while it's (out_channels, in_channels, ...) for Conv2d
+ state_dict[weight_key] = state_dict[weight_key][:, conversion]
+ state_dict[bias_key] = state_dict[bias_key][conversion]
+
+ return state_dict
diff --git a/deeplabcut/pose_estimation_pytorch/models/heads/transformer.py b/deeplabcut/pose_estimation_pytorch/models/heads/transformer.py
new file mode 100644
index 0000000000..0dd014fed5
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/heads/transformer.py
@@ -0,0 +1,101 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import torch
+from einops import rearrange
+from timm.layers import trunc_normal_
+from torch import nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.criterions import BaseCriterion
+from deeplabcut.pose_estimation_pytorch.models.heads import BaseHead, HEADS
+from deeplabcut.pose_estimation_pytorch.models.predictors import BasePredictor
+from deeplabcut.pose_estimation_pytorch.models.target_generators import BaseGenerator
+
+
+@HEADS.register_module
+class TransformerHead(BaseHead):
+ """
+ Transformer Head module to predict heatmaps using a transformer-based approach
+ """
+
+ def __init__(
+ self,
+ predictor: BasePredictor,
+ target_generator: BaseGenerator,
+ criterion: BaseCriterion,
+ dim: int,
+ hidden_heatmap_dim: int,
+ heatmap_dim: int,
+ apply_multi: bool,
+ heatmap_size: tuple[int, int],
+ apply_init: bool,
+ head_stride: int,
+ ):
+ """
+ Args:
+ dim: Dimension of the input features.
+ hidden_heatmap_dim: Dimension of the hidden features in the MLP head.
+ heatmap_dim: Dimension of the output heatmaps.
+ apply_multi: If True, apply a multi-layer perceptron (MLP) with LayerNorm
+ to generate heatmaps. If False, directly apply a single linear
+ layer for heatmap prediction.
+ heatmap_size: Tuple (height, width) representing the size of the output
+ heatmaps.
+ apply_init: If True, apply weight initialization to the module's layers.
+ head_stride: The stride for the head (or neck + head pair), where positive
+ values indicate an increase in resolution while negative values a
+ decrease. Assuming that H and W are divisible by head_stride, this is
+ the value such that if a backbone outputs an encoding of shape
+ (C, H, W), the head will output heatmaps of shape:
+ (C, H * head_stride, W * head_stride) if head_stride > 0
+ (C, -H/head_stride, -W/head_stride) if head_stride < 0
+ """
+ super().__init__(head_stride, predictor, target_generator, criterion)
+ self.mlp_head = (
+ nn.Sequential(
+ nn.LayerNorm(dim * 3),
+ nn.Linear(dim * 3, hidden_heatmap_dim),
+ nn.LayerNorm(hidden_heatmap_dim),
+ nn.Linear(hidden_heatmap_dim, heatmap_dim),
+ )
+ if (dim * 3 <= hidden_heatmap_dim * 0.5 and apply_multi)
+ else nn.Sequential(nn.LayerNorm(dim * 3), nn.Linear(dim * 3, heatmap_dim))
+ )
+ self.heatmap_size = heatmap_size
+ # trunc_normal_(self.keypoint_token, std=.02)
+ if apply_init:
+ self.apply(self._init_weights)
+
+ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+ x = self.mlp_head(x)
+ x = rearrange(
+ x,
+ "b c (p1 p2) -> b c p1 p2",
+ p1=self.heatmap_size[0],
+ p2=self.heatmap_size[1],
+ )
+ return {"heatmap": x}
+
+ def _init_weights(self, m: nn.Module) -> None:
+ """
+ Custom weight initialization for linear and layer normalization layers.
+
+ Args:
+ m: module to initialize
+ """
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
diff --git a/deeplabcut/pose_estimation_pytorch/models/model.py b/deeplabcut/pose_estimation_pytorch/models/model.py
new file mode 100644
index 0000000000..64d7f6c5cf
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/model.py
@@ -0,0 +1,275 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import copy
+import logging
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.models.backbones import BACKBONES, BaseBackbone
+from deeplabcut.pose_estimation_pytorch.models.criterions import (
+ CRITERIONS,
+ LOSS_AGGREGATORS,
+)
+from deeplabcut.pose_estimation_pytorch.models.heads import BaseHead, HEADS
+from deeplabcut.pose_estimation_pytorch.models.necks import BaseNeck, NECKS
+from deeplabcut.pose_estimation_pytorch.models.predictors import PREDICTORS
+from deeplabcut.pose_estimation_pytorch.models.target_generators import (
+ TARGET_GENERATORS,
+)
+
+
+class PoseModel(nn.Module):
+ """A pose estimation model
+
+ A pose estimation model is composed of a backbone, optionally a neck, and an
+ arbitrary number of heads. Outputs are computed as follows:
+ """
+
+ def __init__(
+ self,
+ cfg: dict,
+ backbone: BaseBackbone,
+ heads: dict[str, BaseHead],
+ neck: BaseNeck | None = None,
+ ) -> None:
+ """
+ Args:
+ cfg: configuration dictionary for the model.
+ backbone: backbone network architecture.
+ heads: the heads for the model
+ neck: neck network architecture (default is None). Defaults to None.
+ """
+ super().__init__()
+ self.cfg = cfg
+ self.backbone = backbone
+ self.heads = nn.ModuleDict(heads)
+ self.neck = neck
+ self.output_features = False
+
+ self._strides = {
+ name: _model_stride(self.backbone.stride, head.stride)
+ for name, head in heads.items()
+ }
+
+ def forward(self, x: torch.Tensor) -> dict[str, dict[str, torch.Tensor]]:
+ """
+ Forward pass of the PoseModel.
+
+ Args:
+ x: input images
+
+ Returns:
+ Outputs of head groups
+ """
+ if x.dim() == 3:
+ x = x[None, :]
+ features = self.backbone(x)
+ if self.neck:
+ features = self.neck(features)
+
+ outputs = {}
+ if self.output_features:
+ outputs["backbone"] = dict(features=features)
+
+ for head_name, head in self.heads.items():
+ outputs[head_name] = head(features)
+ return outputs
+
+ def get_loss(
+ self,
+ outputs: dict[str, dict[str, torch.Tensor]],
+ targets: dict[str, dict[str, torch.Tensor]],
+ ) -> dict[str, torch.Tensor]:
+ total_losses = []
+ losses: dict[str, torch.Tensor] = {}
+ for name, head in self.heads.items():
+ head_losses = head.get_loss(outputs[name], targets[name])
+ total_losses.append(head_losses["total_loss"])
+ for k, v in head_losses.items():
+ losses[f"{name}_{k}"] = v
+
+ # TODO: Different aggregation for multi-head loss?
+ losses["total_loss"] = torch.mean(torch.stack(total_losses))
+ return losses
+
+ def get_target(
+ self,
+ outputs: dict[str, dict[str, torch.Tensor]],
+ labels: dict,
+ ) -> dict[str, dict]:
+ """Summary:
+ Get targets for model training.
+
+ Args:
+ outputs: output of each head group
+ labels: dictionary of labels
+
+ Returns:
+ targets: dict of the targets for each model head group
+ """
+ return {
+ name: head.target_generator(self._strides[name], outputs[name], labels)
+ for name, head in self.heads.items()
+ }
+
+ def get_predictions(self, outputs: dict[str, dict[str, torch.Tensor]]) -> dict:
+ """Abstract method for the forward pass of the Predictor.
+
+ Args:
+ outputs: outputs of the model heads
+
+ Returns:
+ A dictionary containing the predictions of each head group
+ """
+ predictions = {
+ name: head.predictor(self._strides[name], outputs[name])
+ for name, head in self.heads.items()
+ }
+ if self.output_features:
+ predictions["backbone"] = outputs["backbone"]
+
+ return predictions
+
+ def get_stride(self, head: str) -> int:
+ """
+ Args:
+ head: The head for which to get the total stride.
+
+ Returns:
+ The total stride for the outputs of the head.
+
+ Raises:
+ ValueError: If there is no such head.
+ """
+ return self._strides[head]
+
+ @staticmethod
+ def build(
+ cfg: dict,
+ weight_init: None | WeightInitialization = None,
+ pretrained_backbone: bool = False,
+ ) -> "PoseModel":
+ """
+ Args:
+ cfg: The configuration of the model to build.
+ weight_init: How model weights should be initialized. If None, ImageNet
+ pre-trained backbone weights are loaded from Timm.
+ pretrained_backbone: Whether to load an ImageNet-pretrained weights for
+ the backbone. This should only be set to True when building a model
+ which will be trained on a transfer learning task.
+
+ Returns:
+ the built pose model
+ """
+ cfg["backbone"]["pretrained"] = pretrained_backbone
+ backbone = BACKBONES.build(dict(cfg["backbone"]))
+
+ neck = None
+ if cfg.get("neck"):
+ neck = NECKS.build(dict(cfg["neck"]))
+
+ heads = {}
+ for name, head_cfg in cfg["heads"].items():
+ head_cfg = copy.deepcopy(head_cfg)
+ if "type" in head_cfg["criterion"]:
+ head_cfg["criterion"] = CRITERIONS.build(head_cfg["criterion"])
+ else:
+ weights = {}
+ criterions = {}
+ for loss_name, criterion_cfg in head_cfg["criterion"].items():
+ weights[loss_name] = criterion_cfg.get("weight", 1.0)
+ criterion_cfg = {
+ k: v for k, v in criterion_cfg.items() if k != "weight"
+ }
+ criterions[loss_name] = CRITERIONS.build(criterion_cfg)
+
+ aggregator_cfg = {"type": "WeightedLossAggregator", "weights": weights}
+ head_cfg["aggregator"] = LOSS_AGGREGATORS.build(aggregator_cfg)
+ head_cfg["criterion"] = criterions
+
+ head_cfg["target_generator"] = TARGET_GENERATORS.build(
+ head_cfg["target_generator"]
+ )
+ head_cfg["predictor"] = PREDICTORS.build(head_cfg["predictor"])
+ heads[name] = HEADS.build(head_cfg)
+
+ model = PoseModel(cfg=cfg, backbone=backbone, neck=neck, heads=heads)
+
+ if weight_init is not None:
+ logging.info(f"Loading pretrained model weights: {weight_init}")
+ logging.info(f"The pose model is loading from {weight_init.snapshot_path}")
+ snapshot = torch.load(weight_init.snapshot_path, map_location="cpu")
+ state_dict = snapshot["model"]
+
+ # load backbone state dict
+ model.backbone.load_state_dict(filter_state_dict(state_dict, "backbone"))
+
+ # if there's a neck, load state dict
+ if model.neck is not None:
+ model.neck.load_state_dict(filter_state_dict(state_dict, "neck"))
+
+ # load head state dicts
+ if weight_init.with_decoder:
+ all_head_state_dicts = filter_state_dict(state_dict, "heads")
+ conversion_tensor = torch.from_numpy(weight_init.conversion_array)
+ for name, head in model.heads.items():
+ head_state_dict = filter_state_dict(all_head_state_dicts, name)
+
+ # requires WeightConversionMixin
+ if not weight_init.memory_replay:
+ head_state_dict = head.convert_weights(
+ state_dict=head_state_dict,
+ module_prefix="",
+ conversion=conversion_tensor,
+ )
+
+ head.load_state_dict(head_state_dict)
+
+ return model
+
+
+def filter_state_dict(state_dict: dict, module: str) -> dict[str, torch.Tensor]:
+ """
+ Filters keys in the state dict for a module to only keep a given prefix. Removes
+ the module from the keys (e.g. for module="backbone", "backbone.stage1.weight" will
+ be converted to "stage1.weight" so the state dict can be loaded into the backbone
+ directly).
+
+ Args:
+ state_dict: the state dict
+ module: the module to keep, e.g. "backbone"
+
+ Returns:
+ the filtered state dict, with the module removed from the keys
+
+ Examples:
+ state_dict = {"backbone.conv.weight": t1, "head.conv.weight": t2}
+ filtered = filter_state_dict(state_dict, "backbone")
+ # filtered = {"conv.weight": t1}
+ model.backbone.load_state_dict(filtered)
+ """
+ return {
+ ".".join(k.split(".")[1:]): v # remove 'backbone.' from the keys
+ for k, v in state_dict.items()
+ if k.startswith(module)
+ }
+
+
+def _model_stride(backbone_stride: int | float, head_stride: int | float) -> float:
+ """Computes the model stride from a backbone and a head"""
+ if head_stride > 0:
+ return backbone_stride / head_stride
+
+ return backbone_stride * -head_stride
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/__init__.py b/deeplabcut/pose_estimation_pytorch/models/modules/__init__.py
new file mode 100644
index 0000000000..be8cb6fead
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/__init__.py
@@ -0,0 +1,24 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.modules.conv_block import (
+ AdaptBlock,
+ BasicBlock,
+ Bottleneck,
+)
+from deeplabcut.pose_estimation_pytorch.models.modules.conv_module import (
+ HighResolutionModule,
+)
+from deeplabcut.pose_estimation_pytorch.models.modules.gated_attention_unit import (
+ GatedAttentionUnit,
+)
+from deeplabcut.pose_estimation_pytorch.models.modules.norm import (
+ ScaleNorm,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/conv_block.py b/deeplabcut/pose_estimation_pytorch/models/modules/conv_block.py
new file mode 100644
index 0000000000..72816bcbcf
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/conv_block.py
@@ -0,0 +1,306 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+import torchvision.ops as ops
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+BLOCKS = Registry("blocks", build_func=build_from_cfg)
+
+
+class BaseBlock(ABC, nn.Module):
+ """Abstract Base class for defining custom blocks.
+
+ This class defines an abstract base class for creating custom blocks used in the HigherHRNet for Human Pose Estimation.
+
+ Attributes:
+ bn_momentum: Batch normalization momentum.
+
+ Methods:
+ forward(x): Abstract method for defining the forward pass of the block.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bn_momentum = 0.1
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor):
+ """Abstract method for defining the forward pass of the block.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ pass
+
+ def _init_weights(self, pretrained: str | None):
+ """Method for initializing block weights from pretrained models.
+
+ Args:
+ pretrained: Path to pretrained model weights.
+ """
+ if pretrained:
+ self.load_state_dict(torch.load(pretrained))
+
+
+@BLOCKS.register_module
+class BasicBlock(BaseBlock):
+ """Basic Residual Block.
+
+ This class defines a basic residual block used in HigherHRNet.
+
+ Attributes:
+ expansion: The expansion factor used in the block.
+
+ Args:
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ stride: Stride value for the convolutional layers. Default is 1.
+ downsample: Downsample layer to be used in the residual connection. Default is None.
+ dilation: Dilation rate for the convolutional layers. Default is 1.
+ """
+
+ expansion: int = 1
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ downsample: nn.Module | None = None,
+ dilation: int = 1,
+ ):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the BasicBlock.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+@BLOCKS.register_module
+class Bottleneck(BaseBlock):
+ """Bottleneck Residual Block.
+
+ This class defines a bottleneck residual block used in HigherHRNet.
+
+ Attributes:
+ expansion: The expansion factor used in the block.
+
+ Args:
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ stride: Stride value for the convolutional layers. Default is 1.
+ downsample: Downsample layer to be used in the residual connection. Default is None.
+ dilation: Dilation rate for the convolutional layers. Default is 1.
+ """
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ downsample: nn.Module | None = None,
+ dilation: int = 1,
+ ):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum)
+ self.conv2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum)
+ self.conv3 = nn.Conv2d(
+ out_channels, out_channels * self.expansion, kernel_size=1, bias=False
+ )
+ self.bn3 = nn.BatchNorm2d(
+ out_channels * self.expansion, momentum=self.bn_momentum
+ )
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the Bottleneck block.
+
+ Args:
+ x : Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+@BLOCKS.register_module
+class AdaptBlock(BaseBlock):
+ """Adaptive Residual Block with Deformable Convolution.
+
+ This class defines an adaptive residual block with deformable convolution used in HigherHRNet.
+
+ Attributes:
+ expansion: The expansion factor used in the block.
+
+ Args:
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ stride: Stride value for the convolutional layers. Default is 1.
+ downsample: Downsample layer to be used in the residual connection. Default is None.
+ dilation: Dilation rate for the convolutional layers. Default is 1.
+ deformable_groups: Number of deformable groups in the deformable convolution. Default is 1.
+ """
+
+ expansion: int = 1
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ downsample: nn.Module | None = None,
+ dilation: int = 1,
+ deformable_groups: int = 1,
+ ):
+ super(AdaptBlock, self).__init__()
+ regular_matrix = torch.tensor(
+ [[-1, -1, -1, 0, 0, 0, 1, 1, 1], [-1, 0, 1, -1, 0, 1, -1, 0, 1]]
+ )
+ self.register_buffer("regular_matrix", regular_matrix.float())
+ self.downsample = downsample
+ self.transform_matrix_conv = nn.Conv2d(in_channels, 4, 3, 1, 1, bias=True)
+ self.translation_conv = nn.Conv2d(in_channels, 2, 3, 1, 1, bias=True)
+ self.adapt_conv = ops.DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ groups=deformable_groups,
+ )
+ self.bn = nn.BatchNorm2d(out_channels, momentum=self.bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass through the AdaptBlock.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ residual = x
+
+ N, _, H, W = x.shape
+ transform_matrix = self.transform_matrix_conv(x)
+ transform_matrix = transform_matrix.permute(0, 2, 3, 1).reshape(
+ (N * H * W, 2, 2)
+ )
+ offset = torch.matmul(transform_matrix, self.regular_matrix)
+ offset = offset - self.regular_matrix
+ offset = offset.transpose(1, 2).reshape((N, H, W, 18)).permute(0, 3, 1, 2)
+
+ translation = self.translation_conv(x)
+ offset[:, 0::2, :, :] += translation[:, 0:1, :, :]
+ offset[:, 1::2, :, :] += translation[:, 1:2, :, :]
+
+ out = self.adapt_conv(x, offset)
+ out = self.bn(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/conv_module.py b/deeplabcut/pose_estimation_pytorch/models/modules/conv_module.py
new file mode 100644
index 0000000000..630eed830f
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/conv_module.py
@@ -0,0 +1,243 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""The code is based on DEKR: https://github.com/HRNet/DEKR/tree/main"""
+import logging
+from typing import List
+
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.models.modules import BasicBlock
+
+BN_MOMENTUM = 0.1
+logger = logging.getLogger(__name__)
+
+
+class HighResolutionModule(nn.Module):
+ """High-Resolution Module.
+
+ This class implements the High-Resolution Module used in HigherHRNet for Human Pose Estimation.
+
+ Args:
+ num_branches: Number of branches in the module.
+ block: The block type used in each branch of the module.
+ num_blocks: List containing the number of blocks in each branch.
+ num_inchannels: List containing the number of input channels for each branch.
+ num_channels: List containing the number of output channels for each branch.
+ fuse_method: The fusion method used in the module.
+ multi_scale_output: Whether to output multi-scale features. Default is True.
+ """
+
+ def __init__(
+ self,
+ num_branches: int,
+ block: BasicBlock,
+ num_blocks: int,
+ num_inchannels: int,
+ num_channels: int,
+ fuse_method: str,
+ multi_scale_output: bool = True,
+ ):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, block, num_blocks, num_inchannels, num_channels
+ )
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, block, num_blocks, num_channels
+ )
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(True)
+
+ def _check_branches(
+ self,
+ num_branches: int,
+ block: BasicBlock,
+ num_blocks: int,
+ num_inchannels: int,
+ num_channels: int,
+ ):
+ if num_branches != len(num_blocks):
+ error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(
+ num_branches, len(num_blocks)
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(
+ num_branches, len(num_channels)
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_inchannels):
+ error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(
+ num_branches, len(num_inchannels)
+ )
+ logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(
+ self,
+ branch_index: int,
+ block: BasicBlock,
+ num_blocks: int,
+ num_channels: int,
+ stride: int = 1,
+ ) -> nn.Sequential:
+ downsample = None
+ if (
+ stride != 1
+ or self.num_inchannels[branch_index]
+ != num_channels[branch_index] * block.expansion
+ ):
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(
+ num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM
+ ),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.num_inchannels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample,
+ )
+ )
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(self.num_inchannels[branch_index], num_channels[branch_index])
+ )
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(
+ self, num_branches: int, block: BasicBlock, num_blocks: int, num_channels: int
+ ) -> nn.ModuleList:
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self) -> nn.ModuleList:
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_inchannels[i],
+ 1,
+ 1,
+ 0,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_inchannels[i]),
+ nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
+ )
+ )
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3,
+ 2,
+ 1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_outchannels_conv3x3),
+ )
+ )
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(
+ nn.Sequential(
+ nn.Conv2d(
+ num_inchannels[j],
+ num_outchannels_conv3x3,
+ 3,
+ 2,
+ 1,
+ bias=False,
+ ),
+ nn.BatchNorm2d(num_outchannels_conv3x3),
+ nn.ReLU(True),
+ )
+ )
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self) -> int:
+ return self.num_inchannels
+
+ def forward(self, x) -> List:
+ """Forward pass through the HighResolutionModule.
+
+ Args:
+ x: List of input tensors for each branch.
+
+ Returns:
+ List of output tensors after processing through the module.
+ """
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+
+ for i in range(len(self.fuse_layers)):
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+
+ return x_fuse
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/csp.py b/deeplabcut/pose_estimation_pytorch/models/modules/csp.py
new file mode 100644
index 0000000000..3099eebeeb
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/csp.py
@@ -0,0 +1,387 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Implementation of modules needed for the CSPNeXt Backbone. Used in CSP-style models.
+
+Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For
+more information, see .
+"""
+import torch
+import torch.nn as nn
+
+
+def build_activation(activation_fn: str, *args, **kwargs) -> nn.Module:
+ if activation_fn == "SiLU":
+ return nn.SiLU(*args, **kwargs)
+ elif activation_fn == "ReLU":
+ return nn.ReLU(*args, **kwargs)
+
+ raise NotImplementedError(
+ f"Unknown `CSPNeXT` activation: {activation_fn}. Must be one of 'SiLU', 'ReLU'"
+ )
+
+
+def build_norm(norm: str, *args, **kwargs) -> nn.Module:
+ if norm == "SyncBN":
+ return nn.SyncBatchNorm(*args, **kwargs)
+ elif norm == "BN":
+ return nn.BatchNorm2d(*args, **kwargs)
+
+ raise NotImplementedError(
+ f"Unknown `CSPNeXT` norm_layer: {norm}. Must be one of 'SyncBN', 'BN'"
+ )
+
+
+class SPPBottleneck(nn.Module):
+ """Spatial pyramid pooling layer used in YOLOv3-SPP and (among others) CSPNeXt
+
+ Args:
+ in_channels: input channels to the bottleneck
+ out_channels: output channels of the bottleneck
+ kernel_sizes: kernel sizes for the pooling layers
+ norm_layer: norm layer for the bottleneck
+ activation_fn: activation function for the bottleneck
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_sizes: tuple[int, ...] = (5, 9, 13),
+ norm_layer: str | None = "SyncBN",
+ activation_fn: str | None = "SiLU",
+ ):
+ super().__init__()
+ mid_channels = in_channels // 2
+ self.conv1 = CSPConvModule(
+ in_channels,
+ mid_channels,
+ kernel_size=1,
+ stride=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+
+ self.poolings = nn.ModuleList(
+ [
+ nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
+ for ks in kernel_sizes
+ ]
+ )
+ conv2_channels = mid_channels * (len(kernel_sizes) + 1)
+ self.conv2 = CSPConvModule(
+ conv2_channels,
+ out_channels,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+
+ def forward(self, x):
+ x = self.conv1(x)
+ with torch.amp.autocast("cuda", enabled=False):
+ x = torch.cat([x] + [pooling(x) for pooling in self.poolings], dim=1)
+ x = self.conv2(x)
+ return x
+
+
+class ChannelAttention(nn.Module):
+ """Channel attention Module.
+
+ Args:
+ channels: Number of input/output channels of the layer.
+ """
+
+ def __init__(self, channels: int) -> None:
+ super().__init__()
+ self.global_avgpool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
+ self.act = nn.Hardsigmoid(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ with torch.amp.autocast("cuda", enabled=False):
+ out = self.global_avgpool(x)
+ out = self.fc(out)
+ out = self.act(out)
+ return x * out
+
+
+class CSPConvModule(nn.Module):
+ """Configurable convolution module used for CSPNeXT.
+
+ Applies sequentially
+ - a convolution
+ - (optional) a norm layer
+ - (optional) an activation function
+
+ Args:
+ in_channels: Input channels of the convolution.
+ out_channels: Output channels of the convolution.
+ kernel_size: Convolution kernel size.
+ stride: Convolution stride.
+ padding: Convolution padding.
+ dilation: Convolution dilation.
+ groups: Number of blocked connections from input to output channels.
+ norm_layer: Norm layer to apply, if any.
+ activation_fn: Activation function to apply, if any.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int | tuple[int, int],
+ stride: int | tuple[int, int] = 1,
+ padding: int | tuple[int, int] = 0,
+ dilation: int | tuple[int, int] = 1,
+ groups: int = 1,
+ norm_layer: str | None = None,
+ activation_fn: str | None = "ReLU",
+ ):
+ super().__init__()
+
+ self.with_activation = activation_fn is not None
+ self.with_bias = norm_layer is None
+ self.with_norm = norm_layer is not None
+
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=self.with_bias,
+ )
+ self.activate = None
+ self.norm = None
+
+ if self.with_norm:
+ self.norm = build_norm(norm_layer, out_channels)
+
+ if self.with_activation:
+ # Careful when adding activation functions: some should not be in-place
+ self.activate = build_activation(activation_fn, inplace=True)
+
+ self._init_weights()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ if self.with_norm:
+ x = self.norm(x)
+ if self.with_activation:
+ x = self.activate(x)
+ return x
+
+ def _init_weights(self) -> None:
+ """Same init as in convolutions"""
+ nn.init.kaiming_normal_(self.conv.weight, a=0, nonlinearity="relu")
+ if self.with_bias:
+ nn.init.constant_(self.conv.bias, 0)
+
+ if self.with_norm:
+ nn.init.constant_(self.norm.weight, 1)
+ nn.init.constant_(self.norm.bias, 0)
+
+
+class DepthwiseSeparableConv(nn.Module):
+ """Depth-wise separable convolution module used for CSPNeXT.
+
+ Applies sequentially
+ - a depth-wise conv
+ - a point-wise conv
+
+ Args:
+ in_channels: Input channels of the convolution.
+ out_channels: Output channels of the convolution.
+ kernel_size: Convolution kernel size.
+ stride: Convolution stride.
+ padding: Convolution padding.
+ dilation: Convolution dilation.
+ norm_layer: Norm layer to apply, if any.
+ activation_fn: Activation function to apply, if any.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int | tuple[int, int],
+ stride: int | tuple[int, int] = 1,
+ padding: int | tuple[int, int] = 0,
+ dilation: int | tuple[int, int] = 1,
+ norm_layer: str | None = None,
+ activation_fn: str | None = "ReLU",
+ ):
+ super().__init__()
+
+ # depthwise convolution
+ self.depthwise_conv = CSPConvModule(
+ in_channels,
+ in_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=in_channels,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+
+ self.pointwise_conv = CSPConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.depthwise_conv(x)
+ x = self.pointwise_conv(x)
+ return x
+
+
+class CSPNeXtBlock(nn.Module):
+ """Basic bottleneck block used in CSPNeXt.
+
+ Args:
+ in_channels: input channels for the block
+ out_channels: output channels for the block
+ expansion: expansion factor for the hidden channels
+ add_identity: add a skip-connection to the block
+ kernel_size: kernel size for the DepthwiseSeparableConv
+ norm_layer: Norm layer to apply, if any.
+ activation_fn: Activation function to apply, if any.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion: float = 0.5,
+ add_identity: bool = True,
+ kernel_size: int = 5,
+ norm_layer: str | None = None,
+ activation_fn: str | None = "ReLU",
+ ) -> None:
+ super().__init__()
+ hidden_channels = int(out_channels * expansion)
+ self.conv1 = CSPConvModule(
+ in_channels,
+ hidden_channels,
+ 3,
+ stride=1,
+ padding=1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ self.conv2 = DepthwiseSeparableConv(
+ hidden_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ self.add_identity = add_identity and in_channels == out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ identity = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+
+ if self.add_identity:
+ return out + identity
+ else:
+ return out
+
+
+class CSPLayer(nn.Module):
+ """Cross Stage Partial Layer.
+
+ Args:
+ in_channels: input channels for the layer
+ out_channels: output channels for the block
+ expand_ratio: expansion factor for the mid-channels
+ num_blocks: the number of blocks to use
+ add_identity: add a skip-connection to the blocks
+ channel_attention: whether to apply channel attention
+ norm_layer: Norm layer to apply, if any.
+ activation_fn: Activation function to apply, if any.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 0.5,
+ num_blocks: int = 1,
+ add_identity: bool = True,
+ channel_attention: bool = False,
+ norm_layer: str | None = None,
+ activation_fn: str | None = "ReLU",
+ ) -> None:
+ super().__init__()
+ mid_channels = int(out_channels * expand_ratio)
+ self.channel_attention = channel_attention
+ self.main_conv = CSPConvModule(
+ in_channels,
+ mid_channels,
+ 1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ self.short_conv = CSPConvModule(
+ in_channels,
+ mid_channels,
+ 1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ self.final_conv = CSPConvModule(
+ 2 * mid_channels,
+ out_channels,
+ 1,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+
+ self.blocks = nn.Sequential(
+ *[
+ CSPNeXtBlock(
+ mid_channels,
+ mid_channels,
+ 1.0,
+ add_identity,
+ norm_layer=norm_layer,
+ activation_fn=activation_fn,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+ if channel_attention:
+ self.attention = ChannelAttention(2 * mid_channels)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward function."""
+ x_short = self.short_conv(x)
+
+ x_main = self.main_conv(x)
+ x_main = self.blocks(x_main)
+
+ x_final = torch.cat((x_main, x_short), dim=1)
+
+ if self.channel_attention:
+ x_final = self.attention(x_final)
+ return self.final_conv(x_final)
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/gated_attention_unit.py b/deeplabcut/pose_estimation_pytorch/models/modules/gated_attention_unit.py
new file mode 100644
index 0000000000..fd12ee43d8
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/gated_attention_unit.py
@@ -0,0 +1,237 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Gated Attention Unit
+
+Based on the building blocks used for the ``mmdetection`` CSPNeXt implementation. For
+more information, see .
+"""
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm.layers as timm_layers
+
+from deeplabcut.pose_estimation_pytorch.models.modules.norm import ScaleNorm
+
+
+def rope(x, dim):
+ """Applies Rotary Position Embedding to input tensor."""
+ shape = x.shape
+ if isinstance(dim, int):
+ dim = [dim]
+
+ spatial_shape = [shape[i] for i in dim]
+ total_len = 1
+ for i in spatial_shape:
+ total_len *= i
+
+ position = torch.reshape(
+ torch.arange(total_len, dtype=torch.int, device=x.device), spatial_shape
+ )
+
+ for i in range(dim[-1] + 1, len(shape) - 1, 1):
+ position = torch.unsqueeze(position, dim=-1)
+
+ half_size = shape[-1] // 2
+ freq_seq = -torch.arange(half_size, dtype=torch.int, device=x.device) / float(
+ half_size
+ )
+ inv_freq = 10000**-freq_seq
+
+ sinusoid = position[..., None] * inv_freq[None, None, :]
+
+ sin = torch.sin(sinusoid)
+ cos = torch.cos(sinusoid)
+ x1, x2 = torch.chunk(x, 2, dim=-1)
+
+ return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
+
+
+class Scale(nn.Module):
+ """Scale vector by element multiplications.
+
+ Args:
+ dim: The dimension of the scale vector.
+ init_value: The initial value of the scale vector.
+ trainable: Whether the scale vector is trainable.
+ """
+
+ def __init__(self, dim, init_value=1.0, trainable=True):
+ super().__init__()
+ self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)
+
+ def forward(self, x):
+ return x * self.scale
+
+
+class GatedAttentionUnit(nn.Module):
+ """Gated Attention Unit (GAU) in RTMBlock"""
+
+ def __init__(
+ self,
+ num_token,
+ in_token_dims,
+ out_token_dims,
+ expansion_factor=2,
+ s=128,
+ eps=1e-5,
+ dropout_rate=0.0,
+ drop_path=0.0,
+ attn_type="self-attn",
+ act_fn="SiLU",
+ bias=False,
+ use_rel_bias=True,
+ pos_enc=False,
+ ):
+ super(GatedAttentionUnit, self).__init__()
+ self.s = s
+ self.num_token = num_token
+ self.use_rel_bias = use_rel_bias
+ self.attn_type = attn_type
+ self.pos_enc = pos_enc
+
+ if drop_path > 0.0:
+ self.drop_path = timm_layers.DropPath(drop_path)
+ else:
+ self.drop_path = nn.Identity()
+
+ self.e = int(in_token_dims * expansion_factor)
+ if use_rel_bias:
+ if attn_type == "self-attn":
+ self.w = nn.Parameter(
+ torch.rand([2 * num_token - 1], dtype=torch.float)
+ )
+ else:
+ self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+ self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
+ self.o = nn.Linear(self.e, out_token_dims, bias=bias)
+
+ if attn_type == "self-attn":
+ self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
+ self.gamma = nn.Parameter(torch.rand((2, self.s)))
+ self.beta = nn.Parameter(torch.rand((2, self.s)))
+ else:
+ self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
+ self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
+ self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
+ nn.init.xavier_uniform_(self.k_fc.weight)
+ nn.init.xavier_uniform_(self.v_fc.weight)
+
+ self.ln = ScaleNorm(in_token_dims, eps=eps)
+
+ nn.init.xavier_uniform_(self.uv.weight)
+
+ if act_fn == "SiLU" or act_fn == nn.SiLU:
+ self.act_fn = nn.SiLU(True)
+ elif act_fn == "ReLU" or act_fn == nn.ReLU:
+ self.act_fn = nn.ReLU(True)
+ else:
+ raise NotImplementedError
+
+ if in_token_dims == out_token_dims:
+ self.shortcut = True
+ self.res_scale = Scale(in_token_dims)
+ else:
+ self.shortcut = False
+
+ self.sqrt_s = math.sqrt(s)
+
+ self.dropout_rate = dropout_rate
+
+ if dropout_rate > 0.0:
+ self.dropout = nn.Dropout(dropout_rate)
+
+ def rel_pos_bias(self, seq_len, k_len=None):
+ """Add relative position bias."""
+
+ if self.attn_type == "self-attn":
+ t = F.pad(self.w[: 2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
+ t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
+ r = (2 * seq_len - 1) // 2
+ t = t[..., r:-r]
+ else:
+ a = rope(self.a.repeat(seq_len, 1), dim=0)
+ b = rope(self.b.repeat(k_len, 1), dim=0)
+ t = torch.bmm(a, b.permute(0, 2, 1))
+ return t
+
+ def _forward(self, inputs):
+ """GAU Forward function."""
+
+ if self.attn_type == "self-attn":
+ x = inputs
+ else:
+ x, k, v = inputs
+
+ x = self.ln(x)
+
+ # [B, K, in_token_dims] -> [B, K, e + e + s]
+ uv = self.uv(x)
+ uv = self.act_fn(uv)
+
+ if self.attn_type == "self-attn":
+ # [B, K, e + e + s] -> [B, K, e], [B, K, e], [B, K, s]
+ u, v, base = torch.split(uv, [self.e, self.e, self.s], dim=2)
+ # [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s]
+ base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta
+
+ if self.pos_enc:
+ base = rope(base, dim=1)
+ # [B, K, 2, s] -> [B, K, s], [B, K, s]
+ q, k = torch.unbind(base, dim=2)
+
+ else:
+ # [B, K, e + s] -> [B, K, e], [B, K, s]
+ u, q = torch.split(uv, [self.e, self.s], dim=2)
+
+ k = self.k_fc(k) # -> [B, K, s]
+ v = self.v_fc(v) # -> [B, K, e]
+
+ if self.pos_enc:
+ q = rope(q, 1)
+ k = rope(k, 1)
+
+ # [B, K, s].permute() -> [B, s, K]
+ # [B, K, s] x [B, s, K] -> [B, K, K]
+ qk = torch.bmm(q, k.permute(0, 2, 1))
+
+ if self.use_rel_bias:
+ if self.attn_type == "self-attn":
+ bias = self.rel_pos_bias(q.size(1))
+ else:
+ bias = self.rel_pos_bias(q.size(1), k.size(1))
+ qk += bias[:, : q.size(1), : k.size(1)]
+ # [B, K, K]
+ kernel = torch.square(F.relu(qk / self.sqrt_s))
+
+ if self.dropout_rate > 0.0:
+ kernel = self.dropout(kernel)
+ # [B, K, K] x [B, K, e] -> [B, K, e]
+ x = u * torch.bmm(kernel, v)
+
+ # [B, K, e] -> [B, K, out_token_dims]
+ x = self.o(x)
+
+ return x
+
+ def forward(self, x):
+ if self.shortcut:
+ if self.attn_type == "cross-attn":
+ res_shortcut = x[0]
+ else:
+ res_shortcut = x
+ main_branch = self.drop_path(self._forward(x))
+ return self.res_scale(res_shortcut) + main_branch
+ else:
+ return self.drop_path(self._forward(x))
diff --git a/deeplabcut/pose_estimation_pytorch/models/modules/norm.py b/deeplabcut/pose_estimation_pytorch/models/modules/norm.py
new file mode 100644
index 0000000000..1cbc0f4f3a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/modules/norm.py
@@ -0,0 +1,41 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Normalization layers"""
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+
+class ScaleNorm(nn.Module):
+ """Implementation of ScaleNorm
+
+ ScaleNorm was introduced in "Transformers without Tears: Improving the Normalization
+ of Self-Attention".
+
+ Code based on the `mmpose` implementation. See https://github.com/open-mmlab/mmpose
+ for more details.
+
+ Args:
+ dim: The dimension of the scale vector.
+ eps: The minimum value in clamp.
+ """
+
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1))
+
+ def forward(self, x):
+ norm = torch.linalg.norm(x, dim=-1, keepdim=True)
+ norm = norm * self.scale
+ return x / norm.clamp(min=self.eps) * self.g
diff --git a/deeplabcut/pose_estimation_pytorch/models/necks/__init__.py b/deeplabcut/pose_estimation_pytorch/models/necks/__init__.py
new file mode 100644
index 0000000000..5b3823ab6b
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/necks/__init__.py
@@ -0,0 +1,12 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.necks.base import BaseNeck, NECKS
+from deeplabcut.pose_estimation_pytorch.models.necks.transformer import Transformer
diff --git a/deeplabcut/pose_estimation_pytorch/models/necks/base.py b/deeplabcut/pose_estimation_pytorch/models/necks/base.py
new file mode 100644
index 0000000000..201456a5d4
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/necks/base.py
@@ -0,0 +1,48 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from abc import ABC, abstractmethod
+
+import torch
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+NECKS = Registry("necks", build_func=build_from_cfg)
+
+
+class BaseNeck(ABC, torch.nn.Module):
+ """Base Neck class for pose estimation"""
+
+ def __init__(self):
+ super().__init__()
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor):
+ """Abstract method for the forward pass through the Neck.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ pass
+
+ def _init_weights(self, pretrained: str):
+ """Initialize the Neck with pretrained weights.
+
+ Args:
+ pretrained: Path to the pretrained weights.
+
+ Returns:
+ None
+ """
+ if pretrained:
+ self.model.load_state_dict(torch.load(pretrained))
diff --git a/deeplabcut/pose_estimation_pytorch/models/necks/layers.py b/deeplabcut/pose_estimation_pytorch/models/necks/layers.py
new file mode 100644
index 0000000000..6c3ce7d45d
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/necks/layers.py
@@ -0,0 +1,287 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+class Residual(torch.nn.Module):
+ """Residual block module.
+
+ This module implements a residual block for the transformer layers.
+
+ Attributes:
+ fn: The function to apply in the residual block.
+ """
+
+ def __init__(self, fn: torch.nn.Module):
+ """Initialize the Residual block.
+
+ Args:
+ fn: The function to apply in the residual block.
+ """
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x: torch.Tensor, **kwargs):
+ """Forward pass through the Residual block.
+
+ Args:
+ x: Input tensor.
+ **kwargs: Additional keyword arguments for the function.
+
+ Returns:
+ Output tensor.
+ """
+ return self.fn(x, **kwargs) + x
+
+
+class PreNorm(torch.nn.Module):
+ """PreNorm block module.
+
+ This module implements pre-normalization for the transformer layers.
+
+ Attributes:
+ dim: Dimension of the input tensor.
+ fn: The function to apply after normalization.
+ fusion_factor: Fusion factor for layer normalization.
+ Defaults to 1.
+ """
+
+ def __init__(self, dim: int, fn: torch.nn.Module, fusion_factor: int = 1):
+ """Initialize the PreNorm block.
+
+ Args:
+ dim: Dimension of the input tensor.
+ fn: The function to apply after normalization.
+ fusion_factor: Fusion factor for layer normalization.
+ Defaults to 1.
+ """
+ super().__init__()
+ self.norm = torch.nn.LayerNorm(dim * fusion_factor)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ """Forward pass through the PreNorm block.
+
+ Args:
+ x: Input tensor.
+ **kwargs: Additional keyword arguments for the function.
+
+ Returns:
+ Output tensor.
+ """
+ return self.fn(self.norm(x), **kwargs)
+
+
+class FeedForward(torch.nn.Module):
+ """FeedForward block module.
+
+ This module implements the feedforward layer in the transformer layers.
+
+ Attributes:
+ dim: Dimension of the input tensor.
+ hidden_dim: Dimension of the hidden layer.
+ dropout: Dropout rate. Defaults to 0.0.
+ """
+
+ def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
+ """Initialize the FeedForward block.
+
+ Args:
+ dim: Dimension of the input tensor.
+ hidden_dim: Dimension of the hidden layer.
+ dropout: Dropout rate. Defaults to 0.0.
+ """
+ super().__init__()
+ self.net = torch.nn.Sequential(
+ torch.nn.Linear(dim, hidden_dim),
+ torch.nn.GELU(),
+ torch.nn.Dropout(dropout),
+ torch.nn.Linear(hidden_dim, dim),
+ torch.nn.Dropout(dropout),
+ )
+
+ def forward(self, x: torch.Tensor):
+ """Forward pass through the FeedForward block.
+
+ Args:
+ x: Input tensor.
+
+ Returns:
+ Output tensor.
+ """
+ return self.net(x)
+
+
+class Attention(torch.nn.Module):
+ """Attention block module.
+
+ This module implements the attention mechanism in the transformer layers.
+
+ Attributes:
+ dim: Dimension of the input tensor.
+ heads: Number of attention heads. Defaults to 8.
+ dropout: Dropout rate. Defaults to 0.0.
+ num_keypoints: Number of keypoints. Defaults to None.
+ scale_with_head: Scale attention with the number of heads.
+ Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dropout: float = 0.0,
+ num_keypoints: int = None,
+ scale_with_head: bool = False,
+ ):
+ """Initialize the Attention block.
+
+ Args:
+ dim: Dimension of the input tensor.
+ heads: Number of attention heads. Defaults to 8.
+ dropout: Dropout rate. Defaults to 0.0.
+ num_keypoints: Number of keypoints. Defaults to None.
+ scale_with_head: Scale attention with the number of heads.
+ Defaults to False.
+ """
+ super().__init__()
+ self.heads = heads
+ self.scale = (dim // heads) ** -0.5 if scale_with_head else dim ** -0.5
+
+ self.to_qkv = torch.nn.Linear(dim, dim * 3, bias=False)
+ self.to_out = torch.nn.Sequential(
+ torch.nn.Linear(dim, dim), torch.nn.Dropout(dropout)
+ )
+ self.num_keypoints = num_keypoints
+
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
+ """Forward pass through the Attention block.
+
+ Args:
+ x: Input tensor.
+ mask: Attention mask. Defaults to None.
+
+ Returns:
+ Output tensor.
+ """
+ b, n, _, h = *x.shape, self.heads
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
+
+ dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
+ mask_value = -torch.finfo(dots.dtype).max
+
+ if mask is not None:
+ mask = F.pad(mask.flatten(1), (1, 0), value=True)
+ assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions"
+ mask = mask[:, None, :] * mask[:, :, None]
+ dots.masked_fill_(~mask, mask_value)
+ del mask
+
+ attn = dots.softmax(dim=-1)
+
+ out = torch.einsum("bhij,bhjd->bhid", attn, v)
+
+ out = rearrange(out, "b h n d -> b n (h d)")
+ out = self.to_out(out)
+ return out
+
+
+class TransformerLayer(torch.nn.Module):
+ """TransformerLayer block module.
+
+ This module implements the Transformer layer in the transformer model.
+
+ Attributes:
+ dim: Dimension of the input tensor.
+ depth: Depth of the transformer layer.
+ heads: Number of attention heads.
+ mlp_dim: Dimension of the MLP layer.
+ dropout: Dropout rate.
+ num_keypoints: Number of keypoints. Defaults to None.
+ all_attn: Apply attention to all keypoints.
+ Defaults to False.
+ scale_with_head: Scale attention with the number of heads.
+ Defaults to False.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int,
+ dropout: float,
+ num_keypoints: int = None,
+ all_attn: bool = False,
+ scale_with_head: bool = False,
+ ):
+ """Initialize the TransformerLayer block.
+
+ Args:
+ dim: Dimension of the input tensor.
+ depth: Depth of the transformer layer.
+ heads: Number of attention heads.
+ mlp_dim: Dimension of the MLP layer.
+ dropout: Dropout rate.
+ num_keypoints: Number of keypoints. Defaults to None.
+ all_attn: Apply attention to all keypoints. Defaults to False.
+ scale_with_head: Scale attention with the number of heads. Defaults to False.
+ """
+ super().__init__()
+ self.layers = torch.nn.ModuleList([])
+ self.all_attn = all_attn
+ self.num_keypoints = num_keypoints
+ for _ in range(depth):
+ self.layers.append(
+ torch.nn.ModuleList(
+ [
+ Residual(
+ PreNorm(
+ dim,
+ Attention(
+ dim,
+ heads=heads,
+ dropout=dropout,
+ num_keypoints=num_keypoints,
+ scale_with_head=scale_with_head,
+ ),
+ )
+ ),
+ Residual(
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
+ ),
+ ]
+ )
+ )
+
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor = None, pos: torch.Tensor = None
+ ):
+ """Forward pass through the TransformerLayer block.
+
+ Args:
+ x: Input tensor.
+ mask: Attention mask. Defaults to None.
+ pos: Positional encoding. Defaults to None.
+
+ Returns:
+ Output tensor.
+ """
+ for idx, (attn, ff) in enumerate(self.layers):
+ if idx > 0 and self.all_attn:
+ x[:, self.num_keypoints :] += pos
+ x = attn(x, mask=mask)
+ x = ff(x)
+ return x
diff --git a/deeplabcut/pose_estimation_pytorch/models/necks/transformer.py b/deeplabcut/pose_estimation_pytorch/models/necks/transformer.py
new file mode 100644
index 0000000000..7b34d49975
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/necks/transformer.py
@@ -0,0 +1,276 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from typing import Tuple
+
+import torch
+from einops import rearrange, repeat
+from timm.layers import trunc_normal_
+
+from deeplabcut.pose_estimation_pytorch.models.necks.base import BaseNeck, NECKS
+from deeplabcut.pose_estimation_pytorch.models.necks.layers import TransformerLayer
+from deeplabcut.pose_estimation_pytorch.models.necks.utils import (
+ make_sine_position_embedding,
+)
+
+MIN_NUM_PATCHES = 16
+BN_MOMENTUM = 0.1
+
+
+@NECKS.register_module
+class Transformer(BaseNeck):
+ """Transformer Neck for pose estimation.
+ title={TokenPose: Learning Keypoint Tokens for Human Pose Estimation},
+ author={Yanjie Li and Shoukui Zhang and Zhicheng Wang and Sen Yang and Wankou Yang and Shu-Tao Xia and Erjin Zhou},
+ booktitle={IEEE/CVF International Conference on Computer Vision (ICCV)},
+ year={2021}
+
+ Args:
+ feature_size: Size of the input feature map (height, width).
+ patch_size: Size of each patch used in the transformer.
+ num_keypoints: Number of keypoints in the pose estimation task.
+ dim: Dimension of the transformer.
+ depth: Number of transformer layers.
+ heads: Number of self-attention heads in the transformer.
+ mlp_dim: Dimension of the MLP used in the transformer.
+ Defaults to 3.
+ apply_init: Whether to apply weight initialization.
+ Defaults to False.
+ heatmap_size: Size of the heatmap. Defaults to [64, 64].
+ channels: Number of channels in each patch. Defaults to 32.
+ dropout: Dropout rate for embeddings. Defaults to 0.0.
+ emb_dropout: Dropout rate for transformer layers.
+ Defaults to 0.0.
+ pos_embedding_type: Type of positional embedding.
+ Either 'sine-full', 'sine', or 'learnable'.
+ Defaults to "sine-full".
+
+ Examples:
+ # Creating a Transformer neck with sine positional embedding
+ transformer = Transformer(
+ feature_size=(128, 128),
+ patch_size=(16, 16),
+ num_keypoints=17,
+ dim=256,
+ depth=6,
+ heads=8,
+ pos_embedding_type="sine"
+ )
+
+ # Creating a Transformer neck with learnable positional embedding
+ transformer = Transformer(
+ feature_size=(256, 256),
+ patch_size=(32, 32),
+ num_keypoints=17,
+ dim=512,
+ depth=12,
+ heads=16,
+ pos_embedding_type="learnable"
+ )
+ """
+
+ def __init__(
+ self,
+ *,
+ feature_size: Tuple[int, int],
+ patch_size: Tuple[int, int],
+ num_keypoints: int,
+ dim: int,
+ depth: int,
+ heads: int,
+ mlp_dim: int = 3,
+ apply_init: bool = False,
+ heatmap_size: Tuple[int, int] = (64, 64),
+ channels: int = 32,
+ dropout: float = 0.0,
+ emb_dropout: float = 0.0,
+ pos_embedding_type: str = "sine-full"
+ ):
+ super().__init__()
+
+ num_patches = (feature_size[0] // (patch_size[0])) * (
+ feature_size[1] // (patch_size[1])
+ )
+ patch_dim = channels * patch_size[0] * patch_size[1]
+
+ self.inplanes = 64
+ self.patch_size = patch_size
+ self.heatmap_size = heatmap_size
+ self.num_keypoints = num_keypoints
+ self.num_patches = num_patches
+ self.pos_embedding_type = pos_embedding_type
+ self.all_attn = self.pos_embedding_type == "sine-full"
+
+ self.keypoint_token = torch.nn.Parameter(
+ torch.zeros(1, self.num_keypoints, dim)
+ )
+ h, w = (
+ feature_size[0] // (self.patch_size[0]),
+ feature_size[1] // (self.patch_size[1]),
+ )
+
+ self._make_position_embedding(w, h, dim, pos_embedding_type)
+
+ self.patch_to_embedding = torch.nn.Linear(patch_dim, dim)
+ self.dropout = torch.nn.Dropout(emb_dropout)
+
+ self.transformer1 = TransformerLayer(
+ dim,
+ depth,
+ heads,
+ mlp_dim,
+ dropout,
+ num_keypoints=num_keypoints,
+ scale_with_head=True,
+ )
+ self.transformer2 = TransformerLayer(
+ dim,
+ depth,
+ heads,
+ mlp_dim,
+ dropout,
+ num_keypoints=num_keypoints,
+ all_attn=self.all_attn,
+ scale_with_head=True,
+ )
+ self.transformer3 = TransformerLayer(
+ dim,
+ depth,
+ heads,
+ mlp_dim,
+ dropout,
+ num_keypoints=num_keypoints,
+ all_attn=self.all_attn,
+ scale_with_head=True,
+ )
+
+ self.to_keypoint_token = torch.nn.Identity()
+
+ if apply_init:
+ self.apply(self._init_weights)
+
+ def _make_position_embedding(
+ self, w: int, h: int, d_model: int, pe_type="learnable"
+ ):
+ """Create position embeddings for the transformer.
+
+ Args:
+ w: Width of the input feature map.
+ h: Height of the input feature map.
+ d_model: Dimension of the transformer encoder.
+ pe_type: Type of position embeddings.
+ Either "learnable" or "sine". Defaults to "learnable".
+ """
+ with torch.no_grad():
+ self.pe_h = h
+ self.pe_w = w
+ length = h * w
+ if pe_type != "learnable":
+ self.pos_embedding = torch.nn.Parameter(
+ make_sine_position_embedding(h, w, d_model), requires_grad=False
+ )
+ else:
+ self.pos_embedding = torch.nn.Parameter(
+ torch.zeros(1, self.num_patches + self.num_keypoints, d_model)
+ )
+
+ def _make_layer(
+ self, block: torch.nn.Module, planes: int, blocks: int, stride: int = 1
+ ) -> torch.nn.Sequential:
+ """Create a layer of the transformer encoder.
+
+ Args:
+ block: The basic building block of the layer.
+ planes: Number of planes in the layer.
+ blocks: Number of blocks in the layer.
+ stride: Stride value. Defaults to 1.
+
+ Returns:
+ The layer of the transformer encoder.
+ """
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = torch.nn.Sequential(
+ torch.nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ torch.nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return torch.nn.Sequential(*layers)
+
+ def _init_weights(self, m: torch.nn.Module):
+ """Initialize the weights of the model.
+
+ Args:
+ m: A module of the model.
+ """
+ print("Initialization...")
+ if isinstance(m, torch.nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, torch.nn.Linear) and m.bias is not None:
+ torch.nn.init.constant_(m.bias, 0)
+ elif isinstance(m, torch.nn.LayerNorm):
+ torch.nn.init.constant_(m.bias, 0)
+ torch.nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, feature: torch.Tensor, mask=None) -> torch.Tensor:
+ """Forward pass through the Transformer neck.
+
+ Args:
+ feature: Input feature map.
+ mask: Mask to apply to the transformer.
+ Defaults to None.
+
+ Returns:
+ Output tensor from the transformer neck.
+
+ Examples:
+ # Assuming feature is a torch.Tensor of shape (batch_size, channels, height, width)
+ output = transformer(feature)
+ """
+ p = self.patch_size
+
+ x = rearrange(
+ feature, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p[0], p2=p[1]
+ )
+ x = self.patch_to_embedding(x)
+
+ b, n, _ = x.shape
+
+ keypoint_tokens = repeat(self.keypoint_token, "() n d -> b n d", b=b)
+ if self.pos_embedding_type in ["sine", "sine-full"]:
+ x += self.pos_embedding[:, :n]
+ x = torch.cat((keypoint_tokens, x), dim=1)
+ else:
+ x = torch.cat((keypoint_tokens, x), dim=1)
+ x += self.pos_embedding[:, : (n + self.num_keypoints)]
+ x = self.dropout(x)
+
+ x1 = self.transformer1(x, mask, self.pos_embedding)
+ x2 = self.transformer2(x1, mask, self.pos_embedding)
+ x3 = self.transformer3(x2, mask, self.pos_embedding)
+
+ x1_out = self.to_keypoint_token(x1[:, 0 : self.num_keypoints])
+ x2_out = self.to_keypoint_token(x2[:, 0 : self.num_keypoints])
+ x3_out = self.to_keypoint_token(x3[:, 0 : self.num_keypoints])
+
+ x = torch.cat((x1_out, x2_out, x3_out), dim=2)
+ return x
diff --git a/deeplabcut/pose_estimation_pytorch/models/necks/utils.py b/deeplabcut/pose_estimation_pytorch/models/necks/utils.py
new file mode 100644
index 0000000000..028078b8ab
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/necks/utils.py
@@ -0,0 +1,60 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import math
+
+import torch
+
+
+def make_sine_position_embedding(
+ h: int, w: int, d_model: int, temperature: int = 10000, scale: float = 2 * math.pi
+) -> torch.Tensor:
+ """Generate sine position embeddings for a given height, width, and model dimension.
+
+ Args:
+ h: Height of the embedding.
+ w: Width of the embedding.
+ d_model: Dimension of the model.
+ temperature: Temperature parameter for position embedding calculation.
+ Defaults to 10000.
+ scale: Scaling factor for position embedding. Defaults to 2 * math.pi.
+
+ Returns:
+ Sine position embeddings with shape (batch_size, d_model, h * w).
+
+ Example:
+ >>> h, w, d_model = 10, 20, 512
+ >>> pos_emb = make_sine_position_embedding(h, w, d_model)
+ >>> print(pos_emb.shape) # Output: torch.Size([1, 512, 200])
+ """
+ area = torch.ones(1, h, w)
+ y_embed = area.cumsum(1, dtype=torch.float32)
+ x_embed = area.cumsum(2, dtype=torch.float32)
+ one_direction_feats = d_model // 2
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale
+
+ dim_t = torch.arange(one_direction_feats, dtype=torch.float32)
+ dim_t = temperature ** (2 * (dim_t // 2) / one_direction_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ pos = pos.flatten(2).permute(0, 2, 1)
+
+ return pos
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/__init__.py b/deeplabcut/pose_estimation_pytorch/models/predictors/__init__.py
new file mode 100644
index 0000000000..59b10222a9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/__init__.py
@@ -0,0 +1,29 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.predictors.base import (
+ PREDICTORS,
+ BasePredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors.dekr_predictor import (
+ DEKRPredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors.identity_predictor import (
+ IdentityPredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors.paf_predictor import (
+ PartAffinityFieldPredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors.sim_cc import (
+ SimCCPredictor,
+)
+from deeplabcut.pose_estimation_pytorch.models.predictors.single_predictor import (
+ HeatmapPredictor,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/base.py b/deeplabcut/pose_estimation_pytorch/models/predictors/base.py
new file mode 100644
index 0000000000..dc9b38aab6
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/base.py
@@ -0,0 +1,66 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+PREDICTORS = Registry("predictors", build_func=build_from_cfg)
+
+
+class BasePredictor(ABC, nn.Module):
+ """The base Predictor class.
+
+ This class is an abstract base class (ABC) for defining predictors used in the DeepLabCut Toolbox.
+ All predictor classes should inherit from this base class and implement the forward method.
+ Regresses keypoint coordinates from a models output maps
+
+ Attributes:
+ num_animals: Number of animals in the project. Should be set in subclasses.
+
+ Example:
+ # Create a subclass that inherits from BasePredictor and implements the forward method.
+ class MyPredictor(BasePredictor):
+ def __init__(self, num_animals):
+ super().__init__()
+ self.num_animals = num_animals
+
+ def forward(self, outputs):
+ # Implement the forward pass of your custom predictor here.
+ pass
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.num_animals = None
+
+ @abstractmethod
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ """Abstract method for the forward pass of the Predictor.
+
+ Args:
+ stride: the stride of the model
+ outputs: outputs of the model heads
+
+ Returns:
+ A dictionary containing a "poses" key with the output tensor as value, and
+ optionally a "unique_bodyparts" with the unique bodyparts tensor as value.
+
+ Raises:
+ NotImplementedError: This method must be implemented in subclasses.
+ """
+ pass
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/dekr_predictor.py b/deeplabcut/pose_estimation_pytorch/models/predictors/dekr_predictor.py
new file mode 100644
index 0000000000..b7f2b57820
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/dekr_predictor.py
@@ -0,0 +1,408 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.predictors import (
+ BasePredictor,
+ PREDICTORS,
+)
+
+
+@PREDICTORS.register_module
+class DEKRPredictor(BasePredictor):
+ """DEKR Predictor class for multi-animal pose estimation.
+
+ This class regresses keypoints and assembles them (if multianimal project)
+ from the output of DEKR (Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression).
+ Based on:
+ Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression
+ Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang
+ CVPR
+ 2021
+ Code based on:
+ https://github.com/HRNet/DEKR
+
+ Args:
+ num_animals (int): Number of animals in the project.
+ detection_threshold (float, optional): Threshold for detection. Defaults to 0.01.
+ apply_sigmoid (bool, optional): Apply sigmoid to heatmaps. Defaults to True.
+ use_heatmap (bool, optional): Use heatmap to refine keypoint predictions. Defaults to True.
+ keypoint_score_type (str): Type of score to compute for keypoints. "heatmap" applies the heatmap
+ score to each keypoint. "center" applies the score of the center of each individual to
+ all of its keypoints. "combined" multiplies the score of the heatmap and individual
+ center for each keypoint.
+
+ Attributes:
+ num_animals (int): Number of animals in the project.
+ detection_threshold (float): Threshold for detection.
+ apply_sigmoid (bool): Apply sigmoid to heatmaps.
+ use_heatmap (bool): Use heatmap.
+ keypoint_score_type (str): Type of score to compute for keypoints. "heatmap" applies the heatmap
+ score to each keypoint. "center" applies the score of the center of each individual to
+ all of its keypoints. "combined" multiplies the score of the heatmap and individual
+ center for each keypoint.
+
+ Example:
+ # Create a DEKRPredictor instance with 2 animals.
+ predictor = DEKRPredictor(num_animals=2)
+
+ # Make a forward pass with outputs and scale factors.
+ outputs = (heatmaps, offsets) # tuple of heatmaps and offsets
+ scale_factors = (0.5, 0.5) # tuple of scale factors for the poses
+ poses_with_scores = predictor.forward(outputs, scale_factors)
+ """
+
+ default_init = {"apply_sigmoid": True, "detection_threshold": 0.01}
+
+ def __init__(
+ self,
+ num_animals: int,
+ detection_threshold: float = 0.01,
+ apply_sigmoid: bool = True,
+ clip_scores: bool = False,
+ use_heatmap: bool = True,
+ keypoint_score_type: str = "combined",
+ max_absorb_distance: int = 75,
+ ):
+ """
+ Args:
+ num_animals: Number of animals in the project.
+ detection_threshold: Threshold for detection
+ apply_sigmoid: Apply sigmoid to heatmaps
+ clip_scores: If a sigmoid is not applied, this can be used to clip scores
+ for predicted keypoints to values in [0, 1].
+ use_heatmap: Use heatmap to refine the keypoint predictions.
+ keypoint_score_type: Type of score to compute for keypoints. "heatmap"
+ applies the heatmap score to each keypoint. "center" applies the score
+ of the center of each individual to all of its keypoints. "combined"
+ multiplies the score of the heatmap and individual for each keypoint.
+ """
+ super().__init__()
+ self.num_animals = num_animals
+ self.detection_threshold = detection_threshold
+ self.apply_sigmoid = apply_sigmoid
+ self.clip_scores = clip_scores
+ self.use_heatmap = use_heatmap
+ self.keypoint_score_type = keypoint_score_type
+ if self.keypoint_score_type not in ("heatmap", "center", "combined"):
+ raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}")
+
+ # TODO: Set as in HRNet/DEKR configs. Define as a constant.
+ self.max_absorb_distance = max_absorb_distance
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ """Forward pass of DEKRPredictor.
+
+ Args:
+ stride: the stride of the model
+ outputs: outputs of the model heads (heatmap, locref)
+
+ Returns:
+ A dictionary containing a "poses" key with the output tensor as value, and
+ optionally a "unique_bodyparts" with the unique bodyparts tensor as value.
+
+ Example:
+ # Assuming you have 'outputs' (heatmaps and offsets) and 'scale_factors' for poses
+ poses_with_scores = predictor.forward(outputs, scale_factors)
+ """
+ heatmaps, offsets = outputs["heatmap"], outputs["offset"]
+ scale_factors = stride, stride
+
+ if self.apply_sigmoid:
+ heatmaps = F.sigmoid(heatmaps)
+
+ posemap = self.offset_to_pose(offsets)
+
+ batch_size, num_joints_with_center, h, w = heatmaps.shape
+ num_joints = num_joints_with_center - 1
+
+ center_heatmaps = heatmaps[:, -1]
+ pose_ind, ctr_scores = self.get_top_values(center_heatmaps)
+
+ posemap = posemap.permute(0, 2, 3, 1).view(batch_size, h * w, -1, 2)
+ poses = torch.zeros(batch_size, pose_ind.shape[1], num_joints, 2).to(
+ ctr_scores.device
+ )
+ for i in range(batch_size):
+ pose = posemap[i, pose_ind[i]]
+ poses[i] = pose
+
+ if self.use_heatmap:
+ poses = self._update_pose_with_heatmaps(poses, heatmaps[:, :-1])
+
+ if self.keypoint_score_type == "center":
+ score = (
+ ctr_scores.unsqueeze(-1)
+ .expand(batch_size, -1, num_joints)
+ .unsqueeze(-1)
+ )
+ elif self.keypoint_score_type == "heatmap":
+ score = self.get_heat_value(poses, heatmaps).unsqueeze(-1)
+ elif self.keypoint_score_type == "combined":
+ center_score = (
+ ctr_scores.unsqueeze(-1)
+ .expand(batch_size, -1, num_joints)
+ .unsqueeze(-1)
+ )
+ htmp_score = self.get_heat_value(poses, heatmaps).unsqueeze(-1)
+ score = center_score * htmp_score
+ else:
+ raise ValueError(f"Unknown keypoint score type: {self.keypoint_score_type}")
+
+ poses[:, :, :, 0] = (
+ poses[:, :, :, 0] * scale_factors[1] + 0.5 * scale_factors[1]
+ )
+ poses[:, :, :, 1] = (
+ poses[:, :, :, 1] * scale_factors[0] + 0.5 * scale_factors[0]
+ )
+
+ if self.clip_scores:
+ score = torch.clip(score, min=0, max=1)
+
+ poses_w_scores = torch.cat([poses, score], dim=3)
+ # self.pose_nms(heatmaps, poses_w_scores)
+ return {"poses": poses_w_scores}
+
+ def get_locations(
+ self, height: int, width: int, device: torch.device
+ ) -> torch.Tensor:
+ """Get locations for offsets.
+
+ Args:
+ height: Height of the offsets.
+ width: Width of the offsets.
+ device: Device to use.
+
+ Returns:
+ Offset locations.
+
+ Example:
+ # Assuming you have 'height', 'width', and 'device'
+ locations = predictor.get_locations(height, width, device)
+ """
+ shifts_x = torch.arange(0, width, step=1, dtype=torch.float32).to(device)
+ shifts_y = torch.arange(0, height, step=1, dtype=torch.float32).to(device)
+ shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
+ shift_x = shift_x.reshape(-1)
+ shift_y = shift_y.reshape(-1)
+ locations = torch.stack((shift_x, shift_y), dim=1)
+ return locations
+
+ def get_reg_poses(self, offsets: torch.Tensor, num_joints: int) -> torch.Tensor:
+ """Get the regression poses from offsets.
+
+ Args:
+ offsets: Offsets tensor.
+ num_joint: Number of joints.
+
+ Returns:
+ Regression poses.
+
+ Example:
+ # Assuming you have 'offsets' tensor and 'num_joints'
+ regression_poses = predictor.get_reg_poses(offsets, num_joints)
+ """
+ batch_size, _, h, w = offsets.shape
+ offsets = offsets.permute(0, 2, 3, 1).reshape(batch_size, h * w, num_joints, 2)
+ locations = self.get_locations(h, w, offsets.device)
+ locations = locations[None, :, None, :].expand(batch_size, -1, num_joints, -1)
+ poses = locations - offsets
+
+ return poses
+
+ def offset_to_pose(self, offsets: torch.Tensor) -> torch.Tensor:
+ """Convert offsets to poses.
+
+ Args:
+ offsets: Offsets tensor.
+
+ Returns:
+ Poses from offsets.
+
+ Example:
+ # Assuming you have 'offsets' tensor
+ poses = predictor.offset_to_pose(offsets)
+ """
+ batch_size, num_offset, h, w = offsets.shape
+ num_joints = int(num_offset / 2)
+ reg_poses = self.get_reg_poses(offsets, num_joints)
+
+ reg_poses = (
+ reg_poses.contiguous()
+ .view(batch_size, h * w, 2 * num_joints)
+ .permute(0, 2, 1)
+ )
+ reg_poses = reg_poses.contiguous().view(batch_size, -1, h, w).contiguous()
+
+ return reg_poses
+
+ def max_pool(self, heatmap: torch.Tensor) -> torch.Tensor:
+ """Apply max pooling to the heatmap.
+
+ Args:
+ heatmap: Heatmap tensor.
+
+ Returns:
+ Max pooled heatmap.
+
+ Example:
+ # Assuming you have 'heatmap' tensor
+ max_pooled_heatmap = predictor.max_pool(heatmap)
+ """
+ pool1 = torch.nn.MaxPool2d(3, 1, 1)
+ pool2 = torch.nn.MaxPool2d(5, 1, 2)
+ pool3 = torch.nn.MaxPool2d(7, 1, 3)
+ map_size = (heatmap.shape[1] + heatmap.shape[2]) / 2.0
+ maxm = pool2(
+ heatmap
+ ) # Here I think pool 2 is a good match for default 17 pos_dist_tresh
+
+ return maxm
+
+ def get_top_values(
+ self, heatmap: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """Get top values from the heatmap.
+
+ Args:
+ heatmap: Heatmap tensor.
+
+ Returns:
+ Position indices and scores.
+
+ Example:
+ # Assuming you have 'heatmap' tensor
+ positions, scores = predictor.get_top_values(heatmap)
+ """
+ maximum = self.max_pool(heatmap)
+ maximum = torch.eq(maximum, heatmap)
+ heatmap *= maximum
+
+ batchsize, ny, nx = heatmap.shape
+ heatmap_flat = heatmap.reshape(batchsize, nx * ny)
+
+ scores, pos_ind = torch.topk(heatmap_flat, self.num_animals, dim=1)
+
+ return pos_ind, scores
+
+ ########## WIP to take heatmap into account for scoring ##########
+ def _update_pose_with_heatmaps(
+ self, _poses: torch.Tensor, kpt_heatmaps: torch.Tensor
+ ):
+ """If a heatmap center is close enough from the regressed point, the final prediction is the center of this heatmap
+
+ Args:
+ poses: poses tensor, shape (batch_size, num_animals, num_keypoints, 2)
+ kpt_heatmaps: heatmaps (does not contain the center heatmap), shape (batch_size, num_keypoints, h, w)
+ """
+ poses = _poses.clone()
+ maxm = self.max_pool(kpt_heatmaps)
+ maxm = torch.eq(maxm, kpt_heatmaps).float()
+ kpt_heatmaps *= maxm
+ batch_size, num_keypoints, h, w = kpt_heatmaps.shape
+ kpt_heatmaps = kpt_heatmaps.view(batch_size, num_keypoints, -1)
+ val_k, ind = kpt_heatmaps.topk(self.num_animals, dim=2)
+
+ x = ind % w
+ y = (ind / w).long()
+ heats_ind = torch.stack((x, y), dim=3)
+
+ for b in range(batch_size):
+ for i in range(num_keypoints):
+ heat_ind = heats_ind[b, i].float()
+ pose_ind = poses[b, :, i]
+ pose_heat_diff = pose_ind[:, None, :] - heat_ind
+ pose_heat_diff.pow_(2)
+ pose_heat_diff = pose_heat_diff.sum(2)
+ pose_heat_diff.sqrt_()
+ keep_ind = torch.argmin(pose_heat_diff, dim=1)
+
+ for p in range(keep_ind.shape[0]):
+ if pose_heat_diff[p, keep_ind[p]] < self.max_absorb_distance:
+ poses[b, p, i] = heat_ind[keep_ind[p]]
+
+ return poses
+
+ def get_heat_value(
+ self, pose_coords: torch.Tensor, heatmaps: torch.Tensor
+ ) -> torch.Tensor:
+ """Get heat values for pose coordinates and heatmaps.
+
+ Args:
+ pose_coords: Pose coordinates tensor (batch_size, num_animals, num_joints, 2)
+ heatmaps: Heatmaps tensor (batch_size, 1+num_joints, h, w).
+
+ Returns:
+ Heat values.
+
+ Example:
+ # Assuming you have 'pose_coords' and 'heatmaps' tensors
+ heat_values = predictor.get_heat_value(pose_coords, heatmaps)
+ """
+ h, w = heatmaps.shape[2:]
+ heatmaps_nocenter = heatmaps[:, :-1].flatten(
+ 2, 3
+ ) # (batch_size, num_joints, h*w)
+
+ # Predicted poses based on the offset can be outside of the image
+ x = torch.clamp(torch.floor(pose_coords[:, :, :, 0]), 0, w - 1).long()
+ y = torch.clamp(torch.floor(pose_coords[:, :, :, 1]), 0, h - 1).long()
+ keypoint_poses = (y * w + x).mT # (batch, num_joints, num_individuals)
+ heatscores = torch.gather(heatmaps_nocenter, 2, keypoint_poses)
+ return heatscores.mT # (batch, num_individuals, num_joints)
+
+ def pose_nms(self, heatmaps: torch.Tensor, poses: torch.Tensor):
+ """Non-Maximum Suppression (NMS) for regressed poses.
+
+ Args:
+ heatmaps: Heatmaps tensor.
+ poses: Pose proposals.
+
+ Returns:
+ None
+
+ Example:
+ # Assuming you have 'heatmaps' and 'poses' tensors
+ predictor.pose_nms(heatmaps, poses)
+ """
+ pose_scores = poses[:, :, :, 2]
+ pose_coords = poses[:, :, :, :2]
+
+ if pose_coords.shape[1] == 0:
+ return [], []
+
+ batch_size, num_people, num_joints, _ = pose_coords.shape
+ heatvals = self.get_heat_value(pose_coords, heatmaps)
+ heat_score = (torch.sum(heatvals, dim=1) / num_joints)[:, 0]
+
+ # return heat_score
+ # pose_score = pose_score*heatvals
+ # poses = torch.cat([pose_coord.cpu(), pose_score.cpu()], dim=2)
+
+ # keep_pose_inds = nms_core(cfg, pose_coord, heat_score)
+ # poses = poses[keep_pose_inds]
+ # heat_score = heat_score[keep_pose_inds]
+
+ # if len(keep_pose_inds) > cfg.DATASET.MAX_NUM_PEOPLE:
+ # heat_score, topk_inds = torch.topk(heat_score,
+ # cfg.DATASET.MAX_NUM_PEOPLE)
+ # poses = poses[topk_inds]
+
+ # poses = [poses.numpy()]
+ # scores = [i[:, 2].mean() for i in poses[0]]
+
+ # return poses, scores
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/identity_predictor.py b/deeplabcut/pose_estimation_pytorch/models/predictors/identity_predictor.py
new file mode 100644
index 0000000000..9f209df4e7
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/identity_predictor.py
@@ -0,0 +1,69 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Predictor to generate identity maps from head outputs"""
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+
+from deeplabcut.pose_estimation_pytorch.models.predictors.base import (
+ BasePredictor,
+ PREDICTORS,
+)
+
+
+@PREDICTORS.register_module
+class IdentityPredictor(BasePredictor):
+ """Predictor to generate identity maps from head outputs
+
+ Attributes:
+ apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True.
+ """
+
+ def __init__(self, apply_sigmoid: bool = True):
+ """
+ Args:
+ apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True.
+ """
+ super().__init__()
+ self.apply_sigmoid = apply_sigmoid
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ """
+ Swaps the dimensions so the heatmap are (batch_size, h, w, num_individuals),
+ optionally applies a sigmoid to the heatmaps, and rescales it to be the size
+ of the original image (so that the identity scores of keypoints can be computed)
+
+ Args:
+ stride: the stride of the model
+ outputs: output of the model identity head, of shape (b, num_idv, w', h')
+
+ Returns:
+ A dictionary containing a "heatmap" key with the identity heatmap tensor as
+ value.
+ """
+ heatmaps = outputs["heatmap"]
+ h_out, w_out = heatmaps.shape[2:]
+ h_in, w_in = int(h_out * stride), int(w_out * stride)
+ heatmaps = F.resize(
+ heatmaps,
+ size=[h_in, w_in],
+ interpolation=F.InterpolationMode.BILINEAR,
+ antialias=True,
+ )
+ if self.apply_sigmoid:
+ heatmaps = self.sigmoid(heatmaps)
+
+ # permute to have shape (batch_size, h, w, num_individuals)
+ heatmaps = heatmaps.permute((0, 2, 3, 1))
+ return {"heatmap": heatmaps}
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/paf_predictor.py b/deeplabcut/pose_estimation_pytorch/models/predictors/paf_predictor.py
new file mode 100644
index 0000000000..f891524fa9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/paf_predictor.py
@@ -0,0 +1,378 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from numpy.typing import NDArray
+
+from deeplabcut.pose_estimation_pytorch.models.predictors.base import (
+ BasePredictor,
+ PREDICTORS,
+)
+from deeplabcut.core import inferenceutils
+
+Graph = list[tuple[int, int]]
+
+
+@PREDICTORS.register_module
+class PartAffinityFieldPredictor(BasePredictor):
+ """Predictor class for multiple animal pose estimation with part affinity fields.
+
+ Args:
+ num_animals: Number of animals in the project.
+ num_multibodyparts: Number of animal's body parts (ignoring unique body parts).
+ num_uniquebodyparts: Number of unique body parts. # FIXME - should not be needed here if we separate the unique bodypart head
+ graph: Part affinity field graph edges.
+ edges_to_keep: List of indices in `graph` of the edges to keep.
+ locref_stdev: Standard deviation for location refinement.
+ nms_radius: Radius of the Gaussian kernel.
+ sigma: Width of the 2D Gaussian distribution.
+ min_affinity: Minimal edge affinity to add a body part to an Assembly.
+
+ Returns:
+ Regressed keypoints from heatmaps, locref_maps and part affinity fields, as in Tensorflow maDLC.
+ """
+
+ default_init = {
+ "locref_stdev": 7.2801,
+ "nms_radius": 5,
+ "sigma": 1,
+ "min_affinity": 0.05,
+ }
+
+ def __init__(
+ self,
+ num_animals: int,
+ num_multibodyparts: int,
+ num_uniquebodyparts: int,
+ graph: Graph,
+ edges_to_keep: list[int],
+ locref_stdev: float,
+ nms_radius: int,
+ sigma: float,
+ min_affinity: float,
+ add_discarded: bool = False,
+ apply_sigmoid: bool = True,
+ clip_scores: bool = False,
+ force_fusion: bool = False,
+ return_preds: bool = False,
+ ):
+ """Initialize the PartAffinityFieldPredictor class.
+
+ Args:
+ num_animals: Number of animals in the project.
+ num_multibodyparts: Number of animal's body parts (ignoring unique body parts).
+ num_uniquebodyparts: Number of unique body parts.
+ graph: Part affinity field graph edges.
+ edges_to_keep: List of indices in `graph` of the edges to keep.
+ locref_stdev: Standard deviation for location refinement.
+ nms_radius: Radius of the Gaussian kernel.
+ sigma: Width of the 2D Gaussian distribution.
+ min_affinity: Minimal edge affinity to add a body part to an Assembly.
+ return_preds: Whether to return predictions alongside the animals' poses
+
+ Returns:
+ None
+ """
+ super().__init__()
+ self.num_animals = num_animals
+ self.num_multibodyparts = num_multibodyparts
+ self.num_uniquebodyparts = num_uniquebodyparts
+ self.graph = graph
+ self.edges_to_keep = edges_to_keep
+ self.locref_stdev = locref_stdev
+ self.nms_radius = nms_radius
+ self.return_preds = return_preds
+ self.sigma = sigma
+ self.apply_sigmoid = apply_sigmoid
+ self.clip_scores = clip_scores
+ self.sigmoid = torch.nn.Sigmoid()
+ self.assembler = inferenceutils.Assembler.empty(
+ num_animals,
+ n_multibodyparts=num_multibodyparts,
+ n_uniquebodyparts=num_uniquebodyparts,
+ graph=graph,
+ paf_inds=edges_to_keep,
+ min_affinity=min_affinity,
+ add_discarded=add_discarded,
+ force_fusion=force_fusion,
+ )
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ """Forward pass of PartAffinityFieldPredictor. Gets predictions from model output.
+
+ Args:
+ stride: the stride of the model
+ outputs: Output tensors from previous layers.
+ output = heatmaps, locref, pafs
+ heatmaps: torch.Tensor([batch_size, num_joints, height, width])
+ locref: torch.Tensor([batch_size, num_joints, height, width])
+
+ Returns:
+ A dictionary containing a "poses" key with the output tensor as value.
+
+ Example:
+ >>> predictor = PartAffinityFieldPredictor(num_animals=3, location_refinement=True, locref_stdev=7.2801)
+ >>> output = (torch.rand(32, 17, 64, 64), torch.rand(32, 34, 64, 64), torch.rand(32, 136, 64, 64))
+ >>> stride = 8
+ >>> poses = predictor.forward(stride, output)
+ """
+ heatmaps = outputs["heatmap"]
+ locrefs = outputs["locref"]
+ pafs = outputs["paf"]
+ scale_factors = stride, stride
+ batch_size, n_channels, height, width = heatmaps.shape
+
+ if self.apply_sigmoid:
+ heatmaps = self.sigmoid(heatmaps)
+
+ # Filter predicted heatmaps with a 2D Gaussian kernel as in:
+ # https://openaccess.thecvf.com/content_CVPR_2020/papers/Huang_The_Devil_Is_in_the_Details_Delving_Into_Unbiased_Data_CVPR_2020_paper.pdf
+ kernel = self.make_2d_gaussian_kernel(
+ sigma=self.sigma, size=self.nms_radius * 2 + 1
+ )[None, None]
+ kernel = kernel.repeat(n_channels, 1, 1, 1).to(heatmaps.device)
+ heatmaps = F.conv2d(
+ heatmaps, kernel, stride=1, padding="same", groups=n_channels
+ )
+
+ peaks = self.find_local_peak_indices_maxpool_nms(
+ heatmaps, self.nms_radius, threshold=0.01
+ )
+ if ~torch.any(peaks):
+ poses = -torch.ones(
+ (batch_size, self.num_animals, self.num_multibodyparts, 5)
+ )
+ results = dict(poses=poses)
+ if self.return_preds:
+ results["preds"] = [dict(coordinates=[[]], costs=[])],
+
+ return results
+
+ locrefs = locrefs.reshape(batch_size, n_channels, 2, height, width)
+ locrefs = locrefs * self.locref_stdev
+ pafs = pafs.reshape(batch_size, -1, 2, height, width)
+
+ graph = [self.graph[ind] for ind in self.edges_to_keep]
+ preds = self.compute_peaks_and_costs(
+ heatmaps,
+ locrefs,
+ pafs,
+ peaks,
+ graph,
+ self.edges_to_keep,
+ scale_factors,
+ n_id_channels=0, # FIXME Handle identity training
+ )
+ poses = -torch.ones((batch_size, self.num_animals, self.num_multibodyparts, 5))
+ poses_unique = -torch.ones((batch_size, 1, self.num_uniquebodyparts, 4))
+ for i, data_dict in enumerate(preds):
+ assemblies, unique = self.assembler._assemble(data_dict, ind_frame=0)
+ if assemblies is not None:
+ for j, assembly in enumerate(assemblies):
+ poses[i, j, :, :4] = torch.from_numpy(assembly.data)
+ poses[i, j, :, 4] = assembly.affinity
+ if unique is not None:
+ poses_unique[i, 0, :, :4] = torch.from_numpy(unique)
+
+ if self.clip_scores:
+ poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1)
+
+ out = {"poses": poses}
+ if self.return_preds:
+ out["preds"] = preds
+ return out
+
+ @staticmethod
+ def find_local_peak_indices_maxpool_nms(
+ input_: torch.Tensor, radius: int, threshold: float
+ ) -> torch.Tensor:
+ pooled = F.max_pool2d(input_, kernel_size=radius, stride=1, padding=radius // 2)
+ maxima = input_ * torch.eq(input_, pooled).float()
+ peak_indices = torch.nonzero(maxima >= threshold, as_tuple=False)
+ return peak_indices.int()
+
+ @staticmethod
+ def make_2d_gaussian_kernel(sigma: float, size: int) -> torch.Tensor:
+ k = torch.arange(-size // 2 + 1, size // 2 + 1, dtype=torch.float32) ** 2
+ k = F.softmax(-k / (2 * (sigma ** 2)), dim=0)
+ return torch.einsum("i,j->ij", k, k)
+
+ @staticmethod
+ def calc_peak_locations(
+ locrefs: torch.Tensor,
+ peak_inds_in_batch: torch.Tensor,
+ strides: tuple[float, float],
+ ) -> torch.Tensor:
+ s, b, r, c = peak_inds_in_batch.T
+ stride_y, stride_x = strides
+ strides = torch.Tensor((stride_x, stride_y)).to(locrefs.device)
+ off = locrefs[s, b, :, r, c]
+ loc = strides * peak_inds_in_batch[:, [3, 2]] + strides // 2 + off
+ return loc
+
+ @staticmethod
+ def compute_edge_costs(
+ pafs: NDArray,
+ peak_inds_in_batch: NDArray,
+ graph: Graph,
+ paf_inds: list[int],
+ n_bodyparts: int,
+ n_points: int = 10,
+ n_decimals: int = 3,
+ ) -> list[dict[int, NDArray]]:
+ # Clip peak locations to PAFs dimensions
+ h, w = pafs.shape[-2:]
+ peak_inds_in_batch[:, 2] = np.clip(peak_inds_in_batch[:, 2], 0, h - 1)
+ peak_inds_in_batch[:, 3] = np.clip(peak_inds_in_batch[:, 3], 0, w - 1)
+
+ n_samples = pafs.shape[0]
+ sample_inds = []
+ edge_inds = []
+ all_edges = []
+ all_peaks = []
+ for i in range(n_samples):
+ samples_i = peak_inds_in_batch[:, 0] == i
+ peak_inds = peak_inds_in_batch[samples_i, 1:]
+ if not np.any(peak_inds):
+ continue
+ peaks = peak_inds[:, 1:]
+ bpt_inds = peak_inds[:, 0]
+ idx = np.arange(peaks.shape[0])
+ idx_per_bpt = {j: idx[bpt_inds == j].tolist() for j in range(n_bodyparts)}
+ edges = []
+ for k, (s, t) in zip(paf_inds, graph):
+ inds_s = idx_per_bpt[s]
+ inds_t = idx_per_bpt[t]
+ if not (inds_s and inds_t):
+ continue
+ candidate_edges = ((i, j) for i in inds_s for j in inds_t)
+ edges.extend(candidate_edges)
+ edge_inds.extend([k] * len(inds_s) * len(inds_t))
+ if not edges:
+ continue
+ sample_inds.extend([i] * len(edges))
+ all_edges.extend(edges)
+ all_peaks.append(peaks[np.asarray(edges)])
+ if not all_peaks:
+ return [dict() for _ in range(n_samples)]
+
+ sample_inds = np.asarray(sample_inds, dtype=np.int32)
+ edge_inds = np.asarray(edge_inds, dtype=np.int32)
+ all_edges = np.asarray(all_edges, dtype=np.int32)
+ all_peaks = np.concatenate(all_peaks)
+ vecs_s = all_peaks[:, 0]
+ vecs_t = all_peaks[:, 1]
+ vecs = vecs_t - vecs_s
+ lengths = np.linalg.norm(vecs, axis=1).astype(np.float32)
+ lengths += np.spacing(1, dtype=np.float32)
+ xy = np.linspace(vecs_s, vecs_t, n_points, axis=1, dtype=np.int32)
+ y = pafs[
+ sample_inds.reshape((-1, 1)),
+ edge_inds.reshape((-1, 1)),
+ :,
+ xy[..., 0],
+ xy[..., 1],
+ ]
+ integ = np.trapz(y, xy[..., ::-1], axis=1)
+ affinities = np.linalg.norm(integ, axis=1).astype(np.float32)
+ affinities /= lengths
+ np.round(affinities, decimals=n_decimals, out=affinities)
+ np.round(lengths, decimals=n_decimals, out=lengths)
+
+ # Form cost matrices
+ all_costs = []
+ for i in range(n_samples):
+ samples_i_mask = sample_inds == i
+ costs = dict()
+ for k in paf_inds:
+ edges_k_mask = edge_inds == k
+ idx = np.flatnonzero(samples_i_mask & edges_k_mask)
+ s, t = all_edges[idx].T
+ n_sources = np.unique(s).size
+ n_targets = np.unique(t).size
+ costs[k] = dict()
+ costs[k]["m1"] = affinities[idx].reshape((n_sources, n_targets))
+ costs[k]["distance"] = lengths[idx].reshape((n_sources, n_targets))
+ all_costs.append(costs)
+
+ return all_costs
+
+ @staticmethod
+ def _linspace(start: torch.Tensor, stop: torch.Tensor, num: int) -> torch.Tensor:
+ # Taken from https://github.com/pytorch/pytorch/issues/61292#issue-937937159
+ steps = torch.linspace(0, 1, num, dtype=torch.float32, device=start.device)
+ steps = steps.reshape([-1, *([1] * start.ndim)])
+ out = start[None] + steps * (stop - start)[None]
+ return out.swapaxes(0, 1)
+
+ def compute_peaks_and_costs(
+ self,
+ heatmaps: torch.Tensor,
+ locrefs: torch.Tensor,
+ pafs: torch.Tensor,
+ peak_inds_in_batch: torch.Tensor,
+ graph: Graph,
+ paf_inds: list[int],
+ strides: tuple[float, float],
+ n_id_channels: int,
+ n_points: int = 10,
+ n_decimals: int = 3,
+ ) -> list[dict[str, NDArray]]:
+ n_samples, n_channels = heatmaps.shape[:2]
+ n_bodyparts = n_channels - n_id_channels
+ pos = self.calc_peak_locations(locrefs, peak_inds_in_batch, strides)
+ pos = np.round(pos.detach().cpu().numpy(), decimals=n_decimals)
+ heatmaps = heatmaps.detach().cpu().numpy()
+ pafs = pafs.detach().cpu().numpy()
+ peak_inds_in_batch = peak_inds_in_batch.detach().cpu().numpy()
+ costs = self.compute_edge_costs(
+ pafs, peak_inds_in_batch, graph, paf_inds, n_bodyparts, n_points, n_decimals
+ )
+ s, b, r, c = peak_inds_in_batch.T
+ prob = np.round(heatmaps[s, b, r, c], n_decimals).reshape((-1, 1))
+ if n_id_channels:
+ ids = np.round(heatmaps[s, -n_id_channels:, r, c], n_decimals)
+
+ peaks_and_costs = []
+ for i in range(n_samples):
+ xy = []
+ p = []
+ id_ = []
+ samples_i_mask = peak_inds_in_batch[:, 0] == i
+ for j in range(n_bodyparts):
+ bpts_j_mask = peak_inds_in_batch[:, 1] == j
+ idx = np.flatnonzero(samples_i_mask & bpts_j_mask)
+ xy.append(pos[idx])
+ p.append(prob[idx])
+ if n_id_channels:
+ id_.append(ids[idx])
+ dict_ = {"coordinates": (xy,), "confidence": p}
+ if costs is not None:
+ dict_["costs"] = costs[i]
+ if n_id_channels:
+ dict_["identity"] = id_
+ peaks_and_costs.append(dict_)
+
+ return peaks_and_costs
+
+ def set_paf_edges_to_keep(self, edge_indices: list[int]) -> None:
+ """Sets the PAF edge indices to use to assemble individuals
+
+ Args:
+ edge_indices: The indices of edges in the graph to keep.
+ """
+ self.edges_to_keep = edge_indices
+ self.assembler.paf_inds = edge_indices
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/sim_cc.py b/deeplabcut/pose_estimation_pytorch/models/predictors/sim_cc.py
new file mode 100644
index 0000000000..8b69eea23c
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/sim_cc.py
@@ -0,0 +1,124 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""SimCC predictor for the RTMPose model
+
+Based on the official ``mmpose`` SimCC codec and RTMCC head implementation. For more
+information, see .
+"""
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.predictors.base import (
+ BasePredictor,
+ PREDICTORS,
+)
+
+
+@PREDICTORS.register_module
+class SimCCPredictor(BasePredictor):
+ """Class used to make pose predictions from RTMPose head outputs
+
+ The RTMPose model uses coordinate classification for pose estimation. For more
+ information, see "SimCC: a Simple Coordinate Classification Perspective for Human
+ Pose Estimation" () and "RTMPose: Real-Time
+ Multi-Person Pose Estimation based on MMPose" ().
+
+ Args:
+ simcc_split_ratio: The split ratio of pixels, as described in SimCC.
+ """
+
+ def __init__(self, simcc_split_ratio: float = 2.0) -> None:
+ super().__init__()
+ self.simcc_split_ratio = simcc_split_ratio
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ simcc_x = outputs["x"].detach().cpu().numpy()
+ simcc_y = outputs["y"].detach().cpu().numpy()
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
+
+ if keypoints.ndim == 2:
+ keypoints = keypoints[None, :]
+ scores = scores[None, :]
+
+ keypoints /= self.simcc_split_ratio
+ scores = scores.reshape((*scores.shape, -1))
+ keypoints_with_score = np.concatenate([keypoints, scores], axis=-1)
+ keypoints_with_score = torch.tensor(keypoints_with_score).unsqueeze(1)
+ return dict(poses=keypoints_with_score)
+
+
+def get_simcc_maximum(
+ simcc_x: np.ndarray,
+ simcc_y: np.ndarray,
+ apply_softmax: bool = False,
+) -> tuple[np.ndarray, np.ndarray]:
+ """Get maximum response location and value from SimCC representations.
+
+ Note:
+ instance number: N
+ num_keypoints: K
+ heatmap height: H
+ heatmap width: W
+
+ Args:
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
+ apply_softmax (bool): whether to apply softmax on the heatmap.
+ Defaults to False.
+
+ Returns:
+ tuple:
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
+ (K, 2) or (N, K, 2)
+ - vals (np.ndarray): values of maximum heatmap responses in shape
+ (K,) or (N, K)
+ """
+
+ assert isinstance(simcc_x, np.ndarray), "simcc_x should be numpy.ndarray"
+ assert isinstance(simcc_y, np.ndarray), "simcc_y should be numpy.ndarray"
+ assert simcc_x.ndim == 2 or simcc_x.ndim == 3, f"Invalid shape {simcc_x.shape}"
+ assert simcc_y.ndim == 2 or simcc_y.ndim == 3, f"Invalid shape {simcc_y.shape}"
+ assert simcc_x.ndim == simcc_y.ndim, f"{simcc_x.shape} != {simcc_y.shape}"
+
+ if simcc_x.ndim == 3:
+ N, K, Wx = simcc_x.shape
+ simcc_x = simcc_x.reshape(N * K, -1)
+ simcc_y = simcc_y.reshape(N * K, -1)
+ else:
+ N = None
+
+ if apply_softmax:
+ simcc_x = simcc_x - np.max(simcc_x, axis=1, keepdims=True)
+ simcc_y = simcc_y - np.max(simcc_y, axis=1, keepdims=True)
+ ex, ey = np.exp(simcc_x), np.exp(simcc_y)
+ simcc_x = ex / np.sum(ex, axis=1, keepdims=True)
+ simcc_y = ey / np.sum(ey, axis=1, keepdims=True)
+
+ x_locs = np.argmax(simcc_x, axis=1)
+ y_locs = np.argmax(simcc_y, axis=1)
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
+ max_val_x = np.amax(simcc_x, axis=1)
+ max_val_y = np.amax(simcc_y, axis=1)
+
+ mask = max_val_x > max_val_y
+ max_val_x[mask] = max_val_y[mask]
+ vals = max_val_x
+ locs[vals <= 0.0] = -1
+
+ if N:
+ locs = locs.reshape(N, K, 2)
+ vals = vals.reshape(N, K)
+
+ return locs, vals
diff --git a/deeplabcut/pose_estimation_pytorch/models/predictors/single_predictor.py b/deeplabcut/pose_estimation_pytorch/models/predictors/single_predictor.py
new file mode 100644
index 0000000000..51a9be11b4
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/predictors/single_predictor.py
@@ -0,0 +1,162 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.predictors.base import (
+ BasePredictor,
+ PREDICTORS,
+)
+
+
+@PREDICTORS.register_module
+class HeatmapPredictor(BasePredictor):
+ """Predictor class for pose estimation from heatmaps (and optionally locrefs).
+
+ Args:
+ location_refinement: Enable location refinement.
+ locref_std: Standard deviation for location refinement.
+ apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True.
+
+ Returns:
+ Regressed keypoints from heatmaps and locref_maps of baseline DLC model (ResNet + Deconv).
+ """
+
+ def __init__(
+ self,
+ apply_sigmoid: bool = True,
+ clip_scores: bool = False,
+ location_refinement: bool = True,
+ locref_std: float = 7.2801,
+ ):
+ """
+ Args:
+ apply_sigmoid: Apply sigmoid to heatmaps. Defaults to True.
+ clip_scores: If a sigmoid is not applied, this can be used to clip scores
+ for predicted keypoints to values in [0, 1].
+ location_refinement : Enable location refinement.
+ locref_std: Standard deviation for location refinement.
+ """
+ super().__init__()
+ self.apply_sigmoid = apply_sigmoid
+ self.clip_scores = clip_scores
+ self.sigmoid = torch.nn.Sigmoid()
+ self.location_refinement = location_refinement
+ self.locref_std = locref_std
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor]:
+ """Forward pass of SinglePredictor. Gets predictions from model output.
+
+ Args:
+ stride: the stride of the model
+ outputs: output of the model heads (heatmap, locref)
+
+ Returns:
+ A dictionary containing a "poses" key with the output tensor as value.
+
+ Example:
+ >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801)
+ >>> stride = 8
+ >>> output = {"heatmap": torch.rand(32, 17, 64, 64), "locref": torch.rand(32, 17, 64, 64)}
+ >>> poses = predictor.forward(stride, output)
+ """
+ heatmaps = outputs["heatmap"]
+ scale_factors = stride, stride
+
+ if self.apply_sigmoid:
+ heatmaps = self.sigmoid(heatmaps)
+
+ heatmaps = heatmaps.permute(0, 2, 3, 1)
+ batch_size, height, width, num_joints = heatmaps.shape
+
+ locrefs = None
+ if self.location_refinement:
+ locrefs = outputs["locref"]
+ locrefs = locrefs.permute(0, 2, 3, 1).reshape(
+ batch_size, height, width, num_joints, 2
+ )
+ locrefs = locrefs * self.locref_std
+
+ poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors)
+
+ if self.clip_scores:
+ poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1)
+
+ return {"poses": poses}
+
+ def get_top_values(
+ self, heatmap: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the top values from the heatmap.
+
+ Args:
+ heatmap: Heatmap tensor.
+
+ Returns:
+ Y and X indices of the top values.
+
+ Example:
+ >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801)
+ >>> heatmap = torch.rand(32, 17, 64, 64)
+ >>> Y, X = predictor.get_top_values(heatmap)
+ """
+ batchsize, ny, nx, num_joints = heatmap.shape
+ heatmap_flat = heatmap.reshape(batchsize, nx * ny, num_joints)
+ heatmap_top = torch.argmax(heatmap_flat, dim=1)
+ y, x = heatmap_top // nx, heatmap_top % nx
+ return y, x
+
+ def get_pose_prediction(
+ self, heatmap: torch.Tensor, locref: torch.Tensor | None, scale_factors
+ ) -> torch.Tensor:
+ """Gets the pose prediction given the heatmaps and locref.
+
+ Args:
+ heatmap: Heatmap tensor with the following format (batch_size, height, width, num_joints)
+ locref: Locref tensor with the following format (batch_size, height, width, num_joints, 2)
+ scale_factors: Scale factors for the poses.
+
+ Returns:
+ Pose predictions of the format: (batch_size, num_people = 1, num_joints, 3)
+
+ Example:
+ >>> predictor = HeatmapPredictor(location_refinement=True, locref_std=7.2801)
+ >>> heatmap = torch.rand(32, 17, 64, 64)
+ >>> locref = torch.rand(32, 17, 64, 64, 2)
+ >>> scale_factors = (0.5, 0.5)
+ >>> poses = predictor.get_pose_prediction(heatmap, locref, scale_factors)
+ """
+ y, x = self.get_top_values(heatmap)
+
+ batch_size, num_joints = x.shape
+
+ dz = torch.zeros((batch_size, 1, num_joints, 3), device=heatmap.device)
+ for b in range(batch_size):
+ for j in range(num_joints):
+ dz[b, 0, j, 2] = heatmap[b, y[b, j], x[b, j], j]
+ if locref is not None:
+ dz[b, 0, j, :2] = locref[b, y[b, j], x[b, j], j, :]
+
+ x, y = torch.unsqueeze(x, 1), torch.unsqueeze(y, 1)
+
+ x = x * scale_factors[1] + 0.5 * scale_factors[1] + dz[:, :, :, 0]
+ y = y * scale_factors[0] + 0.5 * scale_factors[0] + dz[:, :, :, 1]
+
+ pose = torch.zeros((batch_size, 1, num_joints, 3), device=heatmap.device)
+ pose[:, :, :, 0] = x
+ pose[:, :, :, 1] = y
+ pose[:, :, :, 2] = dz[:, :, :, 2]
+ return pose
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/__init__.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/__init__.py
new file mode 100644
index 0000000000..7b7389588b
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/__init__.py
@@ -0,0 +1,28 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.models.target_generators.base import (
+ TARGET_GENERATORS,
+ BaseGenerator,
+ SequentialGenerator,
+)
+from deeplabcut.pose_estimation_pytorch.models.target_generators.dekr_targets import (
+ DEKRGenerator,
+)
+from deeplabcut.pose_estimation_pytorch.models.target_generators.heatmap_targets import (
+ HeatmapGaussianGenerator,
+ HeatmapPlateauGenerator,
+)
+from deeplabcut.pose_estimation_pytorch.models.target_generators.pafs_targets import (
+ PartAffinityFieldGenerator,
+)
+from deeplabcut.pose_estimation_pytorch.models.target_generators.sim_cc import (
+ SimCCGenerator,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/base.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/base.py
new file mode 100644
index 0000000000..6842277331
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/base.py
@@ -0,0 +1,90 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+TARGET_GENERATORS = Registry("target_generators", build_func=build_from_cfg)
+
+
+class BaseGenerator(ABC, nn.Module): # TODO: Should this really be a module?
+ """Generates target maps from ground truth annotations to train models
+
+ The outputs of the target generator are used to compute losses for model heads. If
+ the head outputs "heatmap" and "offset" tensors, then the corresponding generator
+ must output target "heatmap" and "offset" tensors. The targets themselves are
+ dictionaries, and passed as keyword-arguments to the criterions. This allows to
+ pass masks to the criterions.
+
+ Generally, this means that for each head output (such as "heatmap"), a dict will be
+ generated with a "target" key (for the target heatmap) and optionally a "weights"
+ key (see the WeightedCriterion classes).
+ """
+
+ def __init__(self, label_keypoint_key: str = "keypoints"):
+ super().__init__()
+ self.label_keypoint_key = label_keypoint_key
+
+ @abstractmethod
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ """Generates targets
+
+ Args:
+ stride: the stride of the model
+ outputs: output of a model head
+ labels: the labels for the inputs (each tensor should have shape (b, ...))
+
+ Returns:
+ a dictionary mapping the heads to the inputs of the criterion
+ {
+ "heatmap": {
+ "target": heatmaps,
+ "weights": heatmap_weights,
+ },
+ "locref": {
+ "target": locref_map,
+ "weights": locref_weights,
+ }
+ }
+ """
+
+
+@TARGET_GENERATORS.register_module
+class SequentialGenerator(BaseGenerator):
+ def __init__(self, generators: list[dict], label_keypoint_key: str = "keypoints"):
+ super().__init__(label_keypoint_key)
+ self._generators = [TARGET_GENERATORS.build(dict_) for dict_ in generators]
+
+ @property
+ def generators(self):
+ return self._generators
+
+ def forward(
+ self, stride: int, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ dict_ = {}
+ for gen in self.generators:
+ dict_.update(gen(stride, outputs, labels))
+ return dict_
+
+ def __repr__(self):
+ generators_repr = ", ".join(repr(gen) for gen in self._generators)
+ return (
+ f"<{self.__class__.__name__}(generators=[{generators_repr}], "
+ f"label_keypoint_key='{self.label_keypoint_key}')>"
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/dekr_targets.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/dekr_targets.py
new file mode 100644
index 0000000000..5a21e49185
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/dekr_targets.py
@@ -0,0 +1,219 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.base import (
+ BaseGenerator,
+ TARGET_GENERATORS,
+)
+
+
+@TARGET_GENERATORS.register_module
+class DEKRGenerator(BaseGenerator):
+ """
+ Generate ground truth target for DEKR model training based on:
+ Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression
+ Zigang Geng, Ke Sun, Bin Xiao, Zhaoxiang Zhang, Jingdong Wang, CVPR 2021
+ Code based on:
+ https://github.com/HRNet/DEKR
+ """
+
+ def __init__(
+ self, num_joints: int, pos_dist_thresh: int, bg_weight: float = 0.1, **kwargs
+ ):
+ """
+ Args:
+ num_joints: number of keypoints
+ pos_dist_thresh: 3*std of the gaussian
+ bg_weight:background weight. Defaults to 0.1.
+ """
+ super().__init__(**kwargs)
+
+ self.num_joints = num_joints
+ self.num_heatmaps = self.num_joints + 1
+ self.pos_dist_thresh = pos_dist_thresh
+ self.bg_weight = bg_weight
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ """
+ Given the annotations and predictions of your keypoints, this function returns the targets,
+ a dictionary containing the heatmaps, locref_maps and locref_masks.
+ Args:
+ stride: the stride of the model
+ outputs: output of each model head
+ labels: the labels for the inputs (each tensor should have shape (b, ...))
+
+ Returns:
+ The targets for the DEKR heatmap and offset heads:
+ {
+ "heatmap": {
+ "target": heatmaps,
+ "weights": heatmap_weights,
+ },
+ "offset": {
+ "target": offset_map,
+ "weights": offset_weights,
+ }
+ }
+
+ Examples:
+ input:
+ labels = {"keypoints":torch.randint(1,min(image_size),(batch_size, num_animals, num_joints, 2))}
+ prediction = [torch.rand((batch_size, num_joints, image_size[0], image_size[1]))]
+ image_size = (256, 256)
+ output:
+ targets = {
+ "heatmap": {"target": heatmaps, "weights": heatmap_weights},
+ "offset": {"target": offset_map, "weights": offset_masks}
+ }
+ """
+ stride_y, stride_x = stride, stride
+ batch_size, _, output_h, output_w = outputs["heatmap"].shape
+ coords = labels[self.label_keypoint_key].cpu().numpy()
+ area = labels["area"].cpu().numpy()
+
+ assert (
+ self.num_joints + 1 == coords.shape[2]
+ ), f"the number of joints should be {coords.shape}"
+
+ # TODO make it possible to differentiate between center sigma and other sigmas
+ scale = max(1 / stride_x, 1 / stride_y)
+ sgm, ct_sgm = (self.pos_dist_thresh / 2) * scale, self.pos_dist_thresh * scale
+ radius = self.pos_dist_thresh * scale
+
+ heatmap_shape = batch_size, self.num_heatmaps, output_h, output_w
+ heatmaps = np.zeros(heatmap_shape, dtype=np.float32)
+ heatmap_weights = 2 * np.ones(heatmap_shape, dtype=np.float32)
+
+ offset_shape = batch_size, self.num_joints * 2, output_h, output_w
+ offset_map = np.zeros(offset_shape, dtype=np.float32)
+ weight_map = np.zeros(offset_shape, dtype=np.float32)
+
+ area_map = np.zeros((batch_size, output_h, output_w), dtype=np.float32)
+ for b in range(batch_size):
+ for person_id, p in enumerate(coords[b]):
+ idx_center = len(p) - 1
+ ct_x = int(p[-1, 0])
+ ct_y = int(p[-1, 1])
+
+ ct_x_sm = (ct_x - stride_x / 2) / stride_x
+ ct_y_sm = (ct_y - stride_y / 2) / stride_y
+ for idx, pt in enumerate(p):
+ if pt[-1] == -1:
+ # full gradient masking
+ heatmap_weights[b, idx] = 0.0
+ continue
+ elif pt[-1] <= 0:
+ continue
+
+ if idx == idx_center:
+ sigma = ct_sgm
+ else:
+ sigma = sgm
+
+ x, y = pt[0], pt[1]
+ x_sm, y_sm = (
+ (x - stride_x / 2) / stride_x,
+ (y - stride_y / 2) / stride_y,
+ )
+
+ if x_sm < 0 or y_sm < 0 or x_sm >= output_w or y_sm >= output_h:
+ continue
+
+ # HEATMAP COMPUTATION
+ ul = (
+ int(np.floor(x_sm - 3 * sigma - 1)),
+ int(np.floor(y_sm - 3 * sigma - 1)),
+ )
+ br = (
+ int(np.ceil(x_sm + 3 * sigma + 2)),
+ int(np.ceil(y_sm + 3 * sigma + 2)),
+ )
+
+ cc, dd = max(0, ul[0]), min(br[0], output_w)
+ aa, bb = max(0, ul[1]), min(br[1], output_h)
+
+ joint_rg = np.zeros((bb - aa, dd - cc))
+ for sy in range(aa, bb):
+ for sx in range(cc, dd):
+ joint_rg[sy - aa, sx - cc] = dekr_heatmap_val(
+ sigma, sx, sy, x_sm, y_sm
+ )
+
+ heatmaps[b, idx, aa:bb, cc:dd] = np.maximum(
+ heatmaps[b, idx, aa:bb, cc:dd], joint_rg
+ )
+ heatmap_weights[b, idx, aa:bb, cc:dd] = 1.0
+
+ # OFFSET COMPUTATION
+ if idx != idx_center:
+ start_x = max(int(ct_x_sm - radius), 0)
+ start_y = max(int(ct_y_sm - radius), 0)
+ end_x = min(int(ct_x_sm + radius), output_w)
+ end_y = min(int(ct_y_sm + radius), output_h)
+
+ for pos_x in range(start_x, end_x):
+ for pos_y in range(start_y, end_y):
+ offset_x = pos_x - x_sm
+ offset_y = pos_y - y_sm
+ if (
+ offset_map[b, idx * 2, pos_y, pos_x] != 0
+ or offset_map[b, idx * 2 + 1, pos_y, pos_x] != 0
+ ):
+ if area_map[b, pos_y, pos_x] < area[b, person_id]:
+ continue
+ offset_map[b, idx * 2, pos_y, pos_x] = offset_x
+ offset_map[b, idx * 2 + 1, pos_y, pos_x] = offset_y
+ # TODO find a decent constant make weights vary giving animal area
+ weight_map[b, idx * 2, pos_y, pos_x] = 1.0 / np.sqrt(
+ area[b, person_id]
+ )
+ weight_map[
+ b, idx * 2 + 1, pos_y, pos_x
+ ] = 1.0 / np.sqrt(area[b, person_id])
+ area_map[b, pos_y, pos_x] = area[b, person_id]
+
+ heatmap_weights[heatmap_weights == 2] = self.bg_weight
+ return {
+ "heatmap": {
+ "target": torch.tensor(heatmaps, device=outputs["heatmap"].device),
+ "weights": torch.tensor(
+ heatmap_weights, device=outputs["heatmap"].device
+ ),
+ },
+ "offset": {
+ "target": torch.tensor(offset_map, device=outputs["offset"].device),
+ "weights": torch.tensor(weight_map, device=outputs["offset"].device),
+ },
+ }
+
+
+def dekr_heatmap_val(sigma: float, x: float, y: float, x0: float, y0: float) -> float:
+ """
+ Calculates the corresponding heat value of point (x,y) given the heat distribution centered
+ at (x0,y0) and spread value of sigma.
+
+ Args:
+ sigma: controls the spread or width of the heat distribution
+ x: x coord of a point on the image grid
+ y: y coord of a point on the image grid
+ x0: x center coordinate of the heat distribution
+ y0: y center coordinate of the heat distribution
+
+ Returns:
+ g: calculated heat value represents the intensity of the heat at a given position
+ """
+ return np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/heatmap_targets.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/heatmap_targets.py
new file mode 100644
index 0000000000..3c8d8e2f61
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/heatmap_targets.py
@@ -0,0 +1,332 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import abstractmethod
+from enum import Enum
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.base import (
+ BaseGenerator,
+ TARGET_GENERATORS,
+)
+
+
+class HeatmapGenerator(BaseGenerator):
+ """Abstract class to generate target heatmap targets (with/without locref)
+
+ Can generate target heatmaps either for pose estimation (one keypoint), or for
+ individual identification.
+
+ This class is abstract, and heatmap targets should be generated through its
+ subclasses (such as HeatmapPlateauGenerator)
+ """
+
+ class Mode(Enum):
+ """
+ KEYPOINT generates one heatmap per type of keypoint (for pose estimation heads)
+ INDIVIDUAL generates one heatmap per individual (for identification heads)
+ """
+
+ INDIVIDUAL = "INDIVIDUAL"
+ KEYPOINT = "KEYPOINT"
+
+ @classmethod
+ def _missing_(cls, value):
+ if isinstance(value, str):
+ value = value.upper()
+ for member in cls:
+ if member.value == value:
+ return member
+ return None
+
+ def __init__(
+ self,
+ num_heatmaps: int,
+ pos_dist_thresh: int,
+ heatmap_mode: str | Mode = Mode.KEYPOINT,
+ gradient_masking: bool = False,
+ background_weight: float = 0.1,
+ generate_locref: bool = True,
+ locref_std: float = 7.2801,
+ **kwargs,
+ ):
+ """
+ Args:
+ num_heatmaps: the number of heatmaps to generate
+ pos_dist_thresh: 3*std of the gaussian. We think of dist_thresh as a radius
+ and std is a 'diameter'.
+ mode: the mode to generate heatmaps for
+ gradient_masking: Whether to mask the gradient when a bodypart is undefined
+ (has visibility ``0`` in the dataset). WARNING: Do not set this option
+ for bottom-up models, as a keypoint missing for one animal means the
+ gradients for all animals will be set to 0 for that image.
+ Gradients for inputs that have the visibility flag ``-1`` will always be
+ masked, as this flag indicates that the keypoint is not defined for the
+ image.
+ background_weight: If ``gradient_masking == True`, the weight to apply to
+ the loss for background pixels.
+ learned_id_target: whether to generate the heatmap for keypoints
+ or for learned IDs
+ generate_locref: whether to generate location refinement maps
+ locref_std: the STD for the location refinement maps, if defined
+
+ Examples:
+ input:
+ locref_std = 7.2801, default value in pytorch config
+ num_joints = 6
+ po_dist_thresh = 17, default value in pytorch config
+ """
+ super().__init__(**kwargs)
+ self.num_heatmaps = num_heatmaps
+ self.dist_thresh = float(pos_dist_thresh)
+ self.dist_thresh_sq = self.dist_thresh**2
+ self.std = 2 * self.dist_thresh / 3
+
+ if isinstance(heatmap_mode, str):
+ heatmap_mode = HeatmapGenerator.Mode(heatmap_mode)
+ self.heatmap_mode = heatmap_mode
+
+ self.gradient_masking = gradient_masking
+ self.background_weight = background_weight
+
+ self.generate_locref = generate_locref
+ self.locref_scale = 1.0 / locref_std
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ """
+ Given the annotations and predictions of your keypoints, this function returns the targets,
+ a dictionary containing the heatmaps, locref_maps and locref_masks.
+
+ Args:
+ stride: the stride of the model
+ outputs: output of each model head
+ labels: the labels for the inputs (each tensor should have shape (b, ...))
+
+ Returns:
+ The targets for the heatmap and locref heads:
+ {
+ "heatmap": {
+ "target": heatmaps,
+ "weights": heatmap_weights,
+ },
+ "locref": { # optional
+ "target": locref_map,
+ "weights": locref_weights,
+ }
+ }
+
+ Examples:
+ input:
+ annotations = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_joints, 2)
+ )
+ }
+ image_size = (256, 256)
+ model_stride = 4
+ output:
+ targets = {
+ "heatmap": {
+ "target": array of shape (batch_size, 64, 64, num_joints),
+ "weights": array of shape (batch_size, 64, 64, num_joints),
+ },
+ "locref": {
+ "target": array of shape (batch_size, 64, 64, num_joints),
+ "weights": array of shape (batch_size, 64, 64, num_joints),
+ }
+ }
+ """
+ stride_y, stride_x = stride, stride
+ batch_size, _, height, width = outputs["heatmap"].shape
+ coords = labels[self.label_keypoint_key].cpu().numpy()
+ if len(coords.shape) == 3: # for single animal: add individual dimension
+ coords = coords.reshape((batch_size, 1, *coords.shape[1:]))
+
+ if self.heatmap_mode == HeatmapGenerator.Mode.KEYPOINT:
+ # transpose the individuals and keypoints to iterate over bodyparts
+ coords = coords.transpose((0, 2, 1, 3))
+ if self.heatmap_mode == HeatmapGenerator.Mode.INDIVIDUAL:
+ # re-order the individuals to always have the same order
+ # TODO: Optimize
+ sorted_coords = -np.ones_like(coords)
+ for i, batch_individuals in enumerate(labels["individual_ids"]):
+ for j, individual_id in enumerate(batch_individuals):
+ if individual_id >= 0:
+ sorted_coords[i, individual_id] = coords[i, j]
+ coords = sorted_coords
+
+ map_size = batch_size, height, width
+ heatmap = np.zeros((*map_size, self.num_heatmaps), dtype=np.float32)
+ weights = np.ones(
+ (batch_size, self.num_heatmaps, height, width),
+ dtype=np.float32,
+ )
+
+ locref_map, locref_mask = None, None
+ if self.generate_locref:
+ locref_map = np.zeros((*map_size, self.num_heatmaps * 2), dtype=np.float32)
+ locref_mask = np.zeros_like(locref_map, dtype=int)
+
+ grid = np.mgrid[:height, :width].transpose((1, 2, 0))
+ grid[:, :, 0] = grid[:, :, 0] * stride_y + stride_y / 2
+ grid[:, :, 1] = grid[:, :, 1] * stride_x + stride_x / 2
+
+ # heatmap (batch_size, height, width, num_kpts)
+ # coords (batch_size, num_kpts, num_individuals, 3)
+ for b in range(batch_size):
+ for heatmap_idx, group_keypoints in enumerate(coords[b]):
+ for keypoint in group_keypoints:
+ if self.gradient_masking and keypoint[-1] == 0:
+ # apply background weight if keypoints are missing
+ weights[b, heatmap_idx] = self.background_weight
+ elif keypoint[-1] == -1:
+ # always mask weights when the keypoint is undefined
+ weights[b, heatmap_idx] = 0.0
+ elif keypoint[-1] > 0:
+ # keypoint visible
+ self.update(
+ heatmap=heatmap[b, :, :, heatmap_idx],
+ grid=grid,
+ keypoint=keypoint[..., :2],
+ locref_map=self.get_locref(locref_map, b, heatmap_idx),
+ locref_mask=self.get_locref(locref_mask, b, heatmap_idx),
+ )
+
+ hm_device = outputs["heatmap"].device
+ heatmap = heatmap.transpose((0, 3, 1, 2))
+ target = {
+ "heatmap": {
+ "target": torch.tensor(heatmap, device=hm_device),
+ "weights": torch.tensor(weights, device=hm_device),
+ }
+ }
+
+ if self.generate_locref:
+ locref_map = locref_map.transpose((0, 3, 1, 2))
+ locref_mask = locref_mask.transpose((0, 3, 1, 2))
+ target["locref"] = {
+ "target": torch.tensor(locref_map, device=outputs["locref"].device),
+ "weights": torch.tensor(locref_mask, device=outputs["locref"].device),
+ }
+
+ return target
+
+ def get_locref(
+ self,
+ locref_map_or_mask: np.ndarray | None,
+ batch_idx: int,
+ heatmap_idx: int,
+ ) -> np.ndarray | None:
+ """
+ Args:
+ locref_map_or_mask: the locref array to return (either the map or mask), of
+ shape (batch_size, height, width, num_heatmaps)
+ batch_idx: the index of the batch
+ heatmap_idx: the index of the heatmap for which we want the location
+ refinement maps or masks
+
+ Returns:
+ the location refinement maps/masks of shape (height, width, 2)
+ """
+ if not self.generate_locref:
+ return None
+
+ start_idx = 2 * heatmap_idx
+ end_idx = start_idx + 2
+ return locref_map_or_mask[batch_idx, :, :, start_idx:end_idx]
+
+ @abstractmethod
+ def update(
+ self,
+ heatmap: np.ndarray,
+ grid: np.mgrid,
+ keypoint: np.ndarray,
+ locref_map: np.ndarray | None,
+ locref_mask: np.ndarray | None,
+ ) -> None:
+ """
+ Updates the heatmap and locref targets in-place following an update rule (e.g.,
+ Gaussian or Plateau).
+
+ Args:
+ heatmap: the heatmap to update of shape (height, width)
+ grid: the grid for ???
+ keypoint: the keypoint with which to update the maps
+ locref_map: the location refinement maps of shape (height, width, 2), if
+ self.generate_locref = True
+ locref_mask: the location refinement masks of shape (height, width, 2), if
+ self.generate_locref = True
+ """
+ raise NotImplementedError
+
+
+@TARGET_GENERATORS.register_module
+class HeatmapGaussianGenerator(HeatmapGenerator):
+ """Generates gaussian heatmaps (and locref) targets from keypoints"""
+
+ def update(
+ self,
+ heatmap: np.ndarray,
+ grid: np.mgrid,
+ keypoint: np.ndarray,
+ locref_map: np.ndarray | None,
+ locref_mask: np.ndarray | None,
+ ) -> None:
+ """Updates the heatmap (and locref if defined) with gaussian values"""
+ # revert keypoints to follow image convention: from x,y to y,x
+ keypoint = keypoint.copy()[::-1]
+
+ dist = np.linalg.norm(grid - keypoint, axis=2) ** 2
+ heatmap_j = np.exp(-dist / (2 * self.std**2))
+ heatmap[:, :] = np.maximum(heatmap, heatmap_j)
+
+ if locref_map is not None:
+ dx = keypoint[1] - grid.copy()[:, :, 1]
+ dy = keypoint[0] - grid.copy()[:, :, 0]
+ locref_map[:, :, 0] = dx * self.locref_scale
+ locref_map[:, :, 1] = dy * self.locref_scale
+
+ if locref_mask is not None:
+ locref_mask[dist <= self.dist_thresh_sq] = 1
+
+
+@TARGET_GENERATORS.register_module
+class HeatmapPlateauGenerator(HeatmapGenerator):
+ """Generates plateau heatmaps (and locref) targets from keypoints"""
+
+ def update(
+ self,
+ heatmap: np.ndarray,
+ grid: np.mgrid,
+ keypoint: np.ndarray,
+ locref_map: np.ndarray | None,
+ locref_mask: np.ndarray | None,
+ ) -> None:
+ """Updates the heatmap (and locref if defined) with plateau values"""
+ # revert keypoints to follow image convention: from x,y to y,x
+ keypoint = keypoint.copy()[::-1]
+ dist = np.sum((grid - keypoint) ** 2, axis=2)
+ mask = dist <= self.dist_thresh_sq
+ heatmap[mask] = 1
+
+ if locref_map is not None:
+ dx = keypoint[1] - grid.copy()[:, :, 1]
+ dy = keypoint[0] - grid.copy()[:, :, 0]
+ locref_map[mask, 0] = (dx * self.locref_scale)[mask]
+ locref_map[mask, 1] = (dy * self.locref_scale)[mask]
+
+ if locref_mask is not None:
+ locref_mask[mask] = 1
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/pafs_targets.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/pafs_targets.py
new file mode 100644
index 0000000000..6c6d15b5a0
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/pafs_targets.py
@@ -0,0 +1,109 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from math import sqrt
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.base import (
+ BaseGenerator,
+ TARGET_GENERATORS,
+)
+
+
+@TARGET_GENERATORS.register_module
+class PartAffinityFieldGenerator(BaseGenerator):
+ """
+ Generate part affinity field targets from ground truth keypoints in order
+ to train baseline multi-animal deeplabcut model (ResNet + Deconv)
+ """
+
+ def __init__(self, graph: list[list[int, int]], width: float):
+ """
+ Args:
+ graph: list of pairs of keypoint indices forming
+ the graph edges
+ width: width of the vector field in pixels
+
+ Examples:
+ input:
+ graph = [(0, 1), (0, 2), (1, 2)]
+ width = 20.0, default value in pytorch config
+ """
+ super().__init__()
+ self.graph = graph
+ self.width = width
+ self.num_limbs = len(graph)
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ stride_y, stride_x = stride, stride
+ batch_size, _, height, width = outputs["heatmap"].shape
+ coords = labels[self.label_keypoint_key].cpu().numpy()
+
+ paf_map = np.zeros(
+ (batch_size, height, width, self.num_limbs * 2), dtype=np.float32
+ )
+ grid = np.mgrid[:height, :width].transpose((1, 2, 0))
+ grid[:, :, 0] = grid[:, :, 0] * stride_y + stride_y / 2
+ grid[:, :, 1] = grid[:, :, 1] * stride_x + stride_x / 2
+ y, x = np.rollaxis(grid, 2)
+
+ for b in range(batch_size):
+ for _, kpts_animal in enumerate(coords[b]):
+ visible = set(np.flatnonzero(kpts_animal[..., -1] > 0))
+ kpts_animal = kpts_animal[..., :2]
+ for l, (bp1, bp2) in enumerate(self.graph):
+ if not (bp1 in visible and bp2 in visible):
+ continue
+
+ j1_x, j1_y = kpts_animal[bp1]
+ j2_x, j2_y = kpts_animal[bp2]
+ vec_x = j2_x - j1_x
+ vec_y = j2_y - j1_y
+ dist = sqrt(vec_x ** 2 + vec_y ** 2)
+ if dist > 0:
+ vec_x_norm = vec_x / dist
+ vec_y_norm = vec_y / dist
+ vec = [
+ vec_x_norm * j1_x + vec_y_norm * j1_y,
+ vec_x_norm * j2_x + vec_y_norm * j2_y,
+ ]
+ vec_ortho = j1_y * vec_x_norm - j1_x * vec_y_norm
+
+ distance_along = vec_x_norm * x + vec_y_norm * y
+ distance_across = (
+ ((y * vec_x_norm - x * vec_y_norm) - vec_ortho)
+ * 1.0
+ / self.width
+ )
+
+ mask1 = (distance_along >= min(vec)) & (
+ distance_along <= max(vec)
+ )
+ distance_across_abs = np.abs(distance_across)
+ mask2 = distance_across_abs <= 1
+ mask = mask1 & mask2
+ temp = 1 - distance_across_abs[mask]
+ paf_map[b, mask, l * 2 + 0] = vec_x_norm * temp
+ paf_map[b, mask, l * 2 + 1] = vec_y_norm * temp
+
+ paf_map = paf_map.transpose((0, 3, 1, 2))
+ return {
+ "paf": {
+ "target": torch.tensor(
+ paf_map, device=outputs["paf"].device
+ )
+ }
+ }
diff --git a/deeplabcut/pose_estimation_pytorch/models/target_generators/sim_cc.py b/deeplabcut/pose_estimation_pytorch/models/target_generators/sim_cc.py
new file mode 100644
index 0000000000..f0060261ae
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/target_generators/sim_cc.py
@@ -0,0 +1,231 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Modified SimCC target generator for the RTMPose model
+
+Based on the official ``mmpose`` SimCC codec and RTMCC head implementation. For more
+information, see .
+"""
+from __future__ import annotations
+
+from itertools import product
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.base import (
+ BaseGenerator,
+ TARGET_GENERATORS,
+)
+
+
+@TARGET_GENERATORS.register_module
+class SimCCGenerator(BaseGenerator):
+ """Class used generate targets from RTMPose head outputs
+
+ The RTMPose model uses coordinate classification for pose estimation. For more
+ information, see "SimCC: a Simple Coordinate Classification Perspective for Human
+ Pose Estimation" () and "RTMPose: Real-Time
+ Multi-Person Pose Estimation based on MMPose" ().
+
+ Args:
+ input_size: The size of images given to the pose estimation model.
+ smoothing_type: Smoothing strategy ("gaussian" or "standard")
+ sigma: The sigma value in the Gaussian SimCC label. If a single value, used for
+ both x and y. If two values, the sigmas for (x, y).
+ simcc_split_ratio: The split ratio of pixels, as described in SimCC.
+ label_smooth_weight: Label Smoothing weight.
+ normalize: Normalize the heatmaps before returning.
+ **kwargs,
+ """
+
+ def __init__(
+ self,
+ input_size: tuple[int, int],
+ smoothing_type: str = "gaussian",
+ sigma: float | int | tuple[float, ...] = 6.0,
+ simcc_split_ratio: float = 2.0,
+ label_smooth_weight: float = 0.0,
+ normalize: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.input_size = input_size
+ self.smoothing_type = smoothing_type
+ self.simcc_split_ratio = simcc_split_ratio
+ self.label_smooth_weight = label_smooth_weight
+ self.normalize = normalize
+
+ if isinstance(sigma, (float, int)):
+ self.sigma = np.array([sigma, sigma])
+ else:
+ self.sigma = np.array(sigma)
+
+ if self.smoothing_type not in {"gaussian", "standard"}:
+ raise ValueError(
+ f"{self.__class__.__name__} got invalid `smoothing_type` value"
+ f"{self.smoothing_type}. Should be one of "
+ '{"gaussian", "standard"}'
+ )
+
+ if self.smoothing_type == "gaussian" and self.label_smooth_weight > 0:
+ raise ValueError(
+ "Attribute `label_smooth_weight` is only " "used for `standard` mode."
+ )
+
+ if self.label_smooth_weight < 0.0 or self.label_smooth_weight > 1.0:
+ raise ValueError("`label_smooth_weight` should be in range [0, 1]")
+
+ if self.smoothing_type == "gaussian":
+ self.generator = self._generate_gaussian
+ elif self.smoothing_type == "standard":
+ self.generator = self._generate_standard
+ else:
+ raise ValueError(
+ f"{self.__class__.__name__} got invalid `smoothing_type` value"
+ f"{self.smoothing_type}. Should be one of "
+ '{"gaussian", "standard"}'
+ )
+
+ def forward(
+ self, stride: float, outputs: dict[str, torch.Tensor], labels: dict
+ ) -> dict[str, dict[str, torch.Tensor]]:
+ device = outputs["x"].device
+ keypoints = labels[self.label_keypoint_key].cpu().numpy()
+ batch_size = len(keypoints)
+
+ if len(keypoints.shape) == 3: # for single animal: add individual dimension
+ keypoints = keypoints.reshape((batch_size, 1, *keypoints.shape[1:]))
+
+ xs, ys, ws = [], [], []
+ for batch_keypoints in keypoints:
+ keypoints = batch_keypoints[:, :, :2]
+ keypoints_visible = batch_keypoints[:, :, 2]
+ x_labels, y_labels, weights = self.generator(keypoints, keypoints_visible)
+ xs.append(x_labels)
+ ys.append(y_labels)
+ ws.append(weights)
+
+ x_labels = np.stack(xs)
+ y_labels = np.stack(ys)
+ weights = np.stack(ws)
+ return dict(
+ x=dict(
+ target=torch.tensor(x_labels, device=device),
+ weights=torch.tensor(weights, device=device),
+ ),
+ y=dict(
+ target=torch.tensor(y_labels, device=device),
+ weights=torch.tensor(weights, device=device),
+ ),
+ )
+
+ def _generate_standard(
+ self, keypoints: np.ndarray, keypoints_visible: np.ndarray | None = None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Encoding keypoints into SimCC labels with Standard Label Smoothing.
+
+ Labels will be one-hot vectors if self.label_smooth_weight==0.0
+ """
+ N, K, _ = keypoints.shape
+ w, h = self.input_size
+ W = np.around(w * self.simcc_split_ratio).astype(int)
+ H = np.around(h * self.simcc_split_ratio).astype(int)
+
+ keypoints_split, keypoint_weights = self._map_coordinates(
+ keypoints, keypoints_visible
+ )
+
+ target_x = np.zeros((N, K, W), dtype=np.float32)
+ target_y = np.zeros((N, K, H), dtype=np.float32)
+
+ for n, k in product(range(N), range(K)):
+ # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ # get center coordinates
+ mu_x, mu_y = keypoints_split[n, k].astype(np.int64)
+
+ # detect abnormal coords and assign the weight 0
+ if mu_x >= W or mu_y >= H or mu_x < 0 or mu_y < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ if self.label_smooth_weight > 0:
+ target_x[n, k] = self.label_smooth_weight / (W - 1)
+ target_y[n, k] = self.label_smooth_weight / (H - 1)
+
+ target_x[n, k, mu_x] = 1.0 - self.label_smooth_weight
+ target_y[n, k, mu_y] = 1.0 - self.label_smooth_weight
+
+ return target_x, target_y, keypoint_weights
+
+ def _map_coordinates(
+ self, keypoints: np.ndarray, keypoints_visible: np.ndarray | None = None
+ ) -> tuple[np.ndarray, np.ndarray]:
+ """Mapping keypoint coordinates into SimCC space"""
+ keypoints_split = keypoints.copy()
+ # set non-visible keypoints to 0; deals with NaNs
+ keypoints_split[keypoints_visible <= 0] = 0
+ keypoints_split = np.around(keypoints_split * self.simcc_split_ratio)
+ keypoints_split = keypoints_split.astype(np.int64)
+ keypoint_weights = (keypoints_visible > 0).astype(keypoints_split.dtype)
+ return keypoints_split, keypoint_weights
+
+ def _generate_gaussian(
+ self, keypoints: np.ndarray, keypoints_visible: np.ndarray | None = None
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """Encoding keypoints into SimCC labels with Gaussian Label Smoothing"""
+ N, K, _ = keypoints.shape
+ w, h = self.input_size
+ W = np.around(w * self.simcc_split_ratio).astype(int)
+ H = np.around(h * self.simcc_split_ratio).astype(int)
+
+ keypoints_split, keypoint_weights = self._map_coordinates(
+ keypoints, keypoints_visible
+ )
+
+ target_x = np.zeros((N, K, W), dtype=np.float32)
+ target_y = np.zeros((N, K, H), dtype=np.float32)
+
+ # 3-sigma rule
+ radius = self.sigma * 3
+
+ # xy grid
+ x = np.arange(0, W, 1, dtype=np.float32)
+ y = np.arange(0, H, 1, dtype=np.float32)
+
+ for n, k in product(range(N), range(K)):
+ # skip unlabeled keypoints
+ if keypoints_visible[n, k] < 0.5:
+ continue
+
+ mu = keypoints_split[n, k]
+
+ # check that the gaussian has in-bounds part
+ left, top = mu - radius
+ right, bottom = mu + radius + 1
+
+ if left >= W or top >= H or right < 0 or bottom < 0:
+ keypoint_weights[n, k] = 0
+ continue
+
+ mu_x, mu_y = mu
+
+ target_x[n, k] = np.exp(-((x - mu_x) ** 2) / (2 * self.sigma[0] ** 2))
+ target_y[n, k] = np.exp(-((y - mu_y) ** 2) / (2 * self.sigma[1] ** 2))
+
+ if self.normalize:
+ norm_value = self.sigma * np.sqrt(np.pi * 2)
+ target_x /= norm_value[0]
+ target_y /= norm_value[1]
+
+ return target_x, target_y, keypoint_weights
diff --git a/deeplabcut/pose_estimation_pytorch/models/weight_init.py b/deeplabcut/pose_estimation_pytorch/models/weight_init.py
new file mode 100644
index 0000000000..9d2a4d539f
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/models/weight_init.py
@@ -0,0 +1,105 @@
+"""Ways to initialize weights for PyTorch modules"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import torch.nn as nn
+
+from deeplabcut.pose_estimation_pytorch.registry import build_from_cfg, Registry
+
+
+def _build_weight_init(cfg: str | dict, **kwargs) -> BaseWeightInitializer:
+ """Builds a BaseWeightInitializer using its config or the name of the initializer
+
+ Args:
+ cfg: Either the name of the initializer (e.g. 'normal') or the config
+ **kwargs: Other parameters given by the Registry.
+
+ Returns:
+ the built BaseWeightInitializer
+ """
+ if isinstance(cfg, str):
+ cfg = {"type": cfg.title().replace("_", "")}
+ return build_from_cfg(cfg, **kwargs)
+
+
+WEIGHT_INIT = Registry("weight_init", build_func=_build_weight_init)
+
+
+class BaseWeightInitializer(ABC):
+ """Class to used to initialize model weights"""
+
+ @abstractmethod
+ def init_weights(self, model: nn.Module) -> None:
+ """Initializes weights for a model.
+
+ Args:
+ model: The model for which to initialize weights
+ """
+
+
+@WEIGHT_INIT.register_module
+class Normal(BaseWeightInitializer):
+ """Class to used to initialize model weights using a normal distribution
+
+ Weights are initialized with a normal distribution, and biases are initialized to 0.
+
+ Attributes:
+ std: the standard deviation to use to initialize weights
+ """
+
+ def __init__(self, std: float = 0.001):
+ self.std = std
+
+ def init_weights(self, model: nn.Module) -> None:
+ for name, module in model.named_parameters():
+ if "bias" in name:
+ nn.init.constant_(module, 0)
+ else:
+ nn.init.normal_(module, std=self.std)
+
+
+@WEIGHT_INIT.register_module
+class Dekr(BaseWeightInitializer):
+ """Class to used to initialize model weights in the same way as DEKR
+
+ Attributes:
+ std: the standard deviation to use to initialize weights
+ """
+
+ def __init__(self, std: float = 0.001):
+ self.std = std
+
+ def init_weights(self, model: nn.Module) -> None:
+ for name, module in model.named_parameters():
+ if "bias" in name:
+ nn.init.constant_(module, 0)
+ else:
+ nn.init.normal_(module, std=self.std)
+
+ if hasattr(module, "transform_matrix_conv"):
+ nn.init.constant_(module.transform_matrix_conv.weight, 0)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.transform_matrix_conv.bias, 0)
+ if hasattr(module, "translation_conv"):
+ nn.init.constant_(module.translation_conv.weight, 0)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.translation_conv.bias, 0)
+
+
+@WEIGHT_INIT.register_module
+class Rtmpose(BaseWeightInitializer):
+ """Class to used to initialize head weights in the same way as RTMPose"""
+
+ def init_weights(self, model: nn.Module) -> None:
+ for module in model.modules():
+ if isinstance(module, nn.Conv2d):
+ nn.init.normal_(module.weight, std=0.001)
+ nn.init.constant_(module.bias, 0)
+ elif isinstance(module, nn.BatchNorm2d):
+ nn.init.constant_(module.weight, 1)
+ nn.init.constant_(module.bias, 1)
+ elif isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, std=0.01)
+ nn.init.constant_(module.bias, 0)
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/__init__.py b/deeplabcut/pose_estimation_pytorch/modelzoo/__init__.py
new file mode 100644
index 0000000000..e8232cd895
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/__init__.py
@@ -0,0 +1,18 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ download_super_animal_snapshot,
+ get_snapshot_folder_path,
+ get_super_animal_model_config_path,
+ get_super_animal_project_config_path,
+ get_super_animal_snapshot_path,
+ load_super_animal_config,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/config.py b/deeplabcut/pose_estimation_pytorch/modelzoo/config.py
new file mode 100644
index 0000000000..e595b2f621
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/config.py
@@ -0,0 +1,188 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Methods to create the configuration files to fine-tune SuperAnimal models"""
+from __future__ import annotations
+
+import os
+from pathlib import Path
+
+from ruamel.yaml import YAML
+
+import deeplabcut.pose_estimation_pytorch.config.utils as config_utils
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.core.config import (
+ read_config_as_dict,
+ write_config,
+)
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ get_super_animal_model_config_path,
+ get_super_animal_project_config_path,
+)
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+def make_super_animal_finetune_config(
+ weight_init: WeightInitialization,
+ project_config: dict,
+ pose_config_path: str | Path,
+ model_name: str,
+ detector_name: str | None,
+ save: bool = False,
+) -> dict:
+ """
+ Creates a PyTorch pose configuration file to finetune a SuperAnimal model on a
+ downstream project.
+
+ Args:
+ weight_init: The weight initialization configuration.
+ project_config: The project configuration.
+ pose_config_path: The path where the pose configuration file will be saved
+ model_name: The type of neural net to finetune.
+ detector_name: The type of detector to use for the SuperAnimal model. If None is
+ given, the model will be set to a Bottom-Up framework.
+ save: Whether to save the model configuration file to the ``pose_config_path``.
+
+ Returns:
+ The generated pose configuration file.
+
+ Raises:
+ ValueError: If `weight_init.with_decoder = False`. This method only creates
+ configs to fine-tune SuperAnimal models. Call `make_pytorch_pose_config`
+ to create configuration files for transfer learning.
+ """
+ bodyparts = af.get_bodyparts(project_config)
+ if weight_init.dataset is None:
+ raise ValueError(
+ "You must set the ``WeightInitialization.dataset`` when fine-tuning "
+ "SuperAnimal models."
+ )
+
+ if not weight_init.with_decoder:
+ raise ValueError(
+ "Can only call ``make_super_animal_finetune_config`` when "
+ f" `with_decoder=True`, but you had {weight_init}. Please set "
+ "`with_decoder=True` to fine-tune a model or call "
+ "`make_pytorch_pose_config` to create a transfer learning "
+ "pose configuration file."
+ )
+
+ converted_bodyparts = bodyparts
+ if weight_init.bodyparts is not None:
+ assert len(weight_init.bodyparts) == len(weight_init.conversion_array)
+ converted_bodyparts = weight_init.bodyparts
+ elif len(bodyparts) != len(weight_init.conversion_array):
+ raise ValueError(
+ "You don't have the same number of bodyparts in your project config as "
+ f"number of entries your conversion array ({bodyparts} vs "
+ f"{weight_init.conversion_array}). If you're fine-tuning from "
+ "SuperAnimal on a subset of your bodyparts, you must specify which "
+ "ones in `WeightInitialization.bodyparts`. This should be done "
+ "automatically when creating the `weight_init` with "
+ "`WeightInitialization.build`."
+ )
+
+ # Load the exact pose configuration file for the model to fine-tune
+ pose_config = create_config_from_modelzoo(
+ super_animal=weight_init.dataset,
+ model_name=model_name,
+ detector_name=detector_name,
+ converted_bodyparts=converted_bodyparts,
+ weight_init=weight_init,
+ project_config=project_config,
+ pose_config_path=pose_config_path,
+ )
+ if save:
+ write_config(pose_config_path, pose_config, overwrite=True)
+
+ return pose_config
+
+
+def create_config_from_modelzoo(
+ super_animal: str,
+ model_name: str,
+ detector_name: str | None,
+ converted_bodyparts: list[str],
+ weight_init: WeightInitialization,
+ project_config: dict,
+ pose_config_path: str | Path,
+) -> dict:
+ """Creates a model configuration file to fine-tune a SuperAnimal model
+
+ Args:
+ super_animal: The SuperAnimal dataset on which the model was trained.
+ model_name: The type of neural net to finetune.
+ detector_name: The type of detector to use for the SuperAnimal model. If None is
+ given, the model will be set to a Bottom-Up framework.
+ converted_bodyparts: The project bodyparts that the model will learn.
+ weight_init: The weight initialization to use.
+ project_config: The project configuration.
+ pose_config_path: The path where the pose configuration file will be saved.
+
+ Returns:
+ The generated pose configuration file.
+ """
+ # load the model configuration
+ model_cfg = read_config_as_dict(
+ get_super_animal_model_config_path(model_name)
+ )
+ if detector_name is None:
+ model_cfg["method"] = Task.BOTTOM_UP.aliases[0].lower()
+ # Use default bottom-up image augmentation if no detector is given (the collate
+ # function might be needed).
+ config_dir = config_utils.get_config_folder_path()
+ aug = read_config_as_dict(config_dir / "base" / "aug_default.yaml")
+ model_cfg["data"]["train"] = aug["train"]
+ else:
+ model_cfg["method"] = Task.TOP_DOWN.aliases[0].lower()
+ model_cfg["detector"] = read_config_as_dict(
+ get_super_animal_model_config_path(detector_name)
+ )
+
+ # use SuperAnimal bodyparts
+ if weight_init.memory_replay:
+ super_animal_project_config = read_config_as_dict(
+ get_super_animal_project_config_path(super_animal)
+ )
+ converted_bodyparts = super_animal_project_config["bodyparts"]
+
+ model_cfg["net_type"] = model_name
+ model_cfg["metadata"] = {
+ "project_path": project_config["project_path"],
+ "pose_config_path": str(pose_config_path),
+ "bodyparts": converted_bodyparts,
+ "unique_bodyparts": [],
+ "individuals": project_config.get("individuals", ["animal"]),
+ "with_identity": False,
+ }
+
+ model_cfg["model"] = config_utils.replace_default_values(
+ model_cfg["model"], num_bodyparts=len(converted_bodyparts)
+ )
+ model_cfg["train_settings"]["weight_init"] = weight_init.to_dict()
+
+ # sort first-level keys to make it prettier
+ return dict(sorted(model_cfg.items()))
+
+
+def write_pytorch_config_for_memory_replay(config_path, shuffle, pytorch_config):
+ cfg = af.read_config(config_path)
+ trainIndex = 0
+ dlc_proj_root = Path(config_path).parent
+ model_folder = dlc_proj_root / af.get_model_folder(
+ cfg["TrainingFraction"][trainIndex], shuffle, cfg, engine=Engine.PYTORCH
+ )
+ os.makedirs(model_folder / "train", exist_ok=True)
+ out_path = model_folder / "train" / "pytorch_config.yaml"
+ with open(str(out_path), "w") as f:
+ yaml = YAML()
+ yaml.dump(pytorch_config, f)
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/inference.py b/deeplabcut/pose_estimation_pytorch/modelzoo/inference.py
new file mode 100644
index 0000000000..66d77b1551
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/inference.py
@@ -0,0 +1,188 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import json
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+
+from deeplabcut.modelzoo.utils import get_super_animal_scorer, get_superanimal_colormaps
+from deeplabcut.pose_estimation_pytorch.apis.videos import (
+ create_df_from_prediction,
+ video_inference,
+ VideoIterator,
+)
+from deeplabcut.pose_estimation_pytorch.apis.utils import get_inference_runners
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ raise_warning_if_called_directly,
+)
+from deeplabcut.utils.make_labeled_video import create_video
+
+
+class NumpyEncoder(json.JSONEncoder):
+ """Special json encoder for numpy types"""
+
+ def default(self, obj):
+ if isinstance(obj, np.ndarray):
+ return obj.tolist() # Convert ndarray to list
+ return json.JSONEncoder.default(self, obj)
+
+
+def construct_bodypart_names(max_individuals, bodyparts):
+ multianimalbodyparts = []
+ for i in range(max_individuals):
+ for bodypart in bodyparts:
+ multianimalbodyparts.append(f"{bodypart}_{i}")
+ return multianimalbodyparts
+
+
+def _video_inference_superanimal(
+ video_paths: Union[str, list],
+ superanimal_name: str,
+ model_cfg: dict,
+ model_snapshot_path: str | Path,
+ detector_snapshot_path: str | Path | None,
+ max_individuals: int,
+ pcutoff: float,
+ batch_size: int = 1,
+ detector_batch_size: int = 1,
+ cropping: list[int] | None = None,
+ dest_folder: Optional[str] = None,
+ output_suffix: str = "",
+ plot_bboxes: bool = True,
+ bboxes_pcutoff: float = 0.9,
+) -> dict:
+ """
+ Perform inference on a video using a superanimal model from the model zoo specified by `superanimal_name`.
+ During inference, the video is analyzed using the specified model and the results are saved in the specified
+ destination folder. The predictions are saved in the form of a .h5 file. The video with the predictions is saved
+ in the form of a .mp4 file.
+
+ WARNING: This function is an internal utility function and should not be
+ called directly. It is designed to be used by deeplabcut.modelzoo.api.video_inference.py
+
+ Args:
+ video_paths: Path to the video to be analyzed or list of paths to videos to be
+ analyzed
+ superanimal_name: Name of the SuperAnimal project (e.g. superanimal_quadruped)
+ model_cfg: The name of the pose model architecture to use for inference.
+ model_snapshot_path: The path to the pose model snapshot to use for inference.
+ detector_snapshot_path: The path to the detector snapshot to use for inference.
+ max_individuals: Maximum number of individuals in the video
+ pcutoff: Cutoff for cutting off the predicted keypoints with probability lower
+ than pcutoff
+ batch_size: The batch size to use for video inference.
+ cropping: List of cropping coordinates as [x1, x2, y1, y2]. Note that the same
+ cropping parameters will then be used for all videos. If different video
+ crops are desired, run ``video_inference_superanimal`` on individual videos
+ with the corresponding cropping coordinates.
+ detector_batch_size: The batch size to use for the detector for video inference.
+ dest_folder: Destination folder for the results. If not specified, the
+ results are saved in the same folder as the video. Defaults to None.
+ output_suffix: The suffix to add to output file names (e.g. _before_adapt)
+
+ Returns:
+ results: Dictionary with the result pd.DataFrame for each video
+
+ Raises:
+ Warning: If the function is called directly.
+ """
+ raise_warning_if_called_directly()
+ pose_runner, detector_runner = get_inference_runners(
+ model_config=model_cfg,
+ snapshot_path=model_snapshot_path,
+ max_individuals=max_individuals,
+ num_bodyparts=len(model_cfg["metadata"]["bodyparts"]),
+ num_unique_bodyparts=0,
+ batch_size=batch_size,
+ detector_batch_size=detector_batch_size,
+ detector_path=detector_snapshot_path,
+ )
+ results = {}
+
+ if isinstance(video_paths, str):
+ video_paths = [video_paths]
+
+ if dest_folder is None:
+ dest_folder = Path(video_paths[0]).parent
+
+ if not os.path.exists(dest_folder):
+ os.makedirs(dest_folder)
+
+ for video_path in video_paths:
+ print(f"Processing video {video_path}")
+
+ dlc_scorer = get_super_animal_scorer(
+ superanimal_name, model_snapshot_path, detector_snapshot_path
+ )
+
+ output_prefix = f"{Path(video_path).stem}_{dlc_scorer}"
+ output_path = Path(dest_folder)
+ output_h5 = Path(output_path) / f"{output_prefix}.h5"
+
+ output_json = output_h5.with_suffix(".json")
+ if len(output_suffix) > 0:
+ output_json = output_json.with_stem(output_h5.stem + output_suffix)
+
+ video = VideoIterator(video_path, cropping=cropping)
+ predictions = video_inference(
+ video,
+ pose_runner=pose_runner,
+ detector_runner=detector_runner,
+ )
+
+ bbox_keys_in_predictions = {"bboxes", "bbox_scores"}
+ bboxes_list = [
+ {key: value for key, value in p.items() if key in bbox_keys_in_predictions}
+ for i, p in enumerate(predictions)
+ ]
+
+ bbox = cropping
+ if cropping is None:
+ vid_w, vid_h = video.dimensions
+ bbox = (0, vid_w, 0, vid_h)
+
+ print(f"Saving results to {dest_folder}")
+ df = create_df_from_prediction(
+ predictions=predictions,
+ dlc_scorer=dlc_scorer,
+ multi_animal=True,
+ model_cfg=model_cfg,
+ output_path=output_path,
+ output_prefix=output_prefix,
+ )
+
+ results[video_path] = df
+ with open(output_json, "w") as f:
+ json.dump(predictions, f, cls=NumpyEncoder)
+
+ output_video = output_path / f"{output_prefix}_labeled.mp4"
+ if len(output_suffix) > 0:
+ output_video = output_video.with_stem(output_video.stem + output_suffix)
+
+ superanimal_colormaps = get_superanimal_colormaps()
+ colormap = superanimal_colormaps[superanimal_name]
+ create_video(
+ video_path,
+ output_h5,
+ pcutoff=pcutoff,
+ fps=video.fps,
+ bbox=bbox,
+ cmap=colormap,
+ output_path=str(output_video),
+ plot_bboxes=plot_bboxes,
+ bboxes_list=bboxes_list,
+ bboxes_pcutoff=bboxes_pcutoff,
+ )
+ print(f"Video with predictions was saved as {output_path}")
+
+ return results
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/memory_replay.py b/deeplabcut/pose_estimation_pytorch/modelzoo/memory_replay.py
new file mode 100644
index 0000000000..002f74bb20
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/memory_replay.py
@@ -0,0 +1,370 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import json
+import os
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from scipy.spatial import distance
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.modelzoo.generalized_data_converter.datasets import (
+ COCOPoseDataset,
+ MaDLCPoseDataset,
+ SingleDLCPoseDataset,
+)
+from deeplabcut.pose_estimation_pytorch.apis.utils import get_inference_runners
+from deeplabcut.pose_estimation_pytorch.data.dlcloader import DLCLoader
+from deeplabcut.pose_estimation_pytorch.modelzoo import (
+ get_super_animal_project_config_path,
+)
+from deeplabcut.utils.pseudo_label import calculate_iou
+
+
+def get_pose_predictions(
+ loader: DLCLoader,
+ images: list[str],
+ bboxes: dict[str, list],
+ superanimal_name: str,
+ model_snapshot_path: str | Path,
+ detector_snapshot_path: str | Path,
+ max_individuals: int,
+ device: str | None = None,
+) -> dict[str, dict]:
+ """Gets predictions made by a SuperAnimal model on a DeepLabCut project
+
+ Args:
+ loader: The path to the root of the project.
+ images: The images on which to run inference with the SuperAnimal model.
+ bboxes: The ground truth bounding boxes for each image in the project.
+ superanimal_name: The name of the SuperAnimal dataset being used.
+ model_snapshot_path: The path to the SuperAnimal pose snapshot.
+ detector_snapshot_path: The path to the SuperAnimal detector snapshot.
+ max_individuals: The maximum number of individuals to detect per image.
+ device: The CUDA device to use.
+
+ Returns:
+ The predictions made by the SuperAnimal model on each image in the images list.
+ """
+ model_name = detector_snapshot_path.stem + "-" + model_snapshot_path.stem
+ predictions_folder = (
+ loader.project_path / "memory_replay" / superanimal_name / model_name
+ )
+ predictions_folder.mkdir(exist_ok=True, parents=True)
+ predictions_file = predictions_folder / "pseudo-labels.json"
+
+ # COCO-format annotations file containing predictions made by the SuperAnimal model
+ sa_predictions = {}
+ if predictions_file.exists():
+ with open(predictions_file, "r") as f:
+ raw_sa_predictions = json.load(f)
+
+ # parse predictions to convert lists to numpy arrays
+ for image, predictions in raw_sa_predictions.items():
+ sa_predictions[image] = {
+ "bodyparts": np.array(predictions["bodyparts"]),
+ "bboxes": np.array(predictions["bboxes"]),
+ # "bbox_scores": np.array(predictions["bbox_scores"]),
+ }
+
+ # get images that need to be processed
+ processed_images = set(sa_predictions.keys())
+ images_to_process = [image for image in (set(images) - processed_images)]
+
+ # if all images have been processed by the SuperAnimal model, return the predictions
+ if len(images_to_process) == 0:
+ return sa_predictions
+
+ pose_runner, detector_runner = get_inference_runners(
+ loader.model_cfg,
+ snapshot_path=model_snapshot_path,
+ max_individuals=max_individuals,
+ num_bodyparts=len(loader.model_cfg["metadata"]["bodyparts"]),
+ num_unique_bodyparts=len(loader.model_cfg["metadata"]["unique_bodyparts"]),
+ device=device,
+ detector_path=detector_snapshot_path,
+ )
+
+ # FIXME(niels, yeshaokai) - Use the detector to combine GT-keypoint created bounding
+ # boxes and predicted bounding boxes - keep the larger of the two
+ # bbox_predictions = detector_runner.inference(images=images_to_process)
+ pose_inputs = [
+ (
+ str(loader.project_path / Path(image)),
+ {"bboxes": np.array(bboxes[image])}
+ )
+ for image in images_to_process
+ ]
+ predictions = pose_runner.inference(pose_inputs)
+
+ for image, prediction in zip(images_to_process, predictions):
+ sa_predictions[image] = prediction
+
+ # save the updated SuperAnimal predictions
+ json_sa_predictions = {
+ image: {
+ "bodyparts": predictions["bodyparts"].tolist(),
+ "bboxes": predictions["bboxes"].tolist(),
+ # "bbox_scores": predictions["bbox_scores"].tolist(),
+ }
+ for image, predictions in sa_predictions.items()
+ }
+ with open(predictions_file, "w") as f:
+ json.dump(json_sa_predictions, f, indent=2)
+
+ return sa_predictions
+
+
+# this is reading from a coco project
+def prepare_memory_replay_dataset(
+ loader: DLCLoader,
+ source_dataset_folder: str | Path,
+ superanimal_name: str,
+ model_snapshot_path: str,
+ detector_snapshot_path: str,
+ max_individuals: int = 1,
+ train_file: str = "train.json",
+ pose_threshold: float = 0.0,
+ device: str | None = None,
+):
+ """
+ Need to first run inference on the source project train file
+ """
+ project_root = loader.project_path.resolve()
+ source_dataset_folder = Path(source_dataset_folder).resolve()
+
+ # Contains the ground truth annotations for the DeepLabCut project
+ # .../dlc-models-pytorch/.../...shuffle0/train/memory_replay/annotations/train.json
+ with open(source_dataset_folder / "annotations" / train_file, "r") as f:
+ project_gt = json.load(f)
+
+ # parse the GT so that image paths are in the format (no matter the OS):
+ # "labeled-data/{video_name}/{image_name}"
+ for image in project_gt["images"]:
+ image["file_name"] = "/".join(Path(image["file_name"]).parts[-3:])
+
+ image_id_to_name = {}
+ image_id_to_annotations = defaultdict(list)
+
+ image_name_to_id = {}
+ image_name_to_gt = defaultdict(list)
+ image_name_to_bbox = defaultdict(list)
+
+ for image in project_gt["images"]:
+ image_name_to_id[image["file_name"]] = image["id"]
+ image_id_to_name[image["id"]] = image["file_name"]
+
+ for anno in project_gt["annotations"]:
+ name = image_id_to_name[anno["image_id"]]
+ image_name_to_gt[name].append(anno)
+ image_name_to_bbox[name].append(anno["bbox"])
+
+ image_ids = list(image_name_to_id.values())
+ for annotation in project_gt["annotations"]:
+ image_id = annotation["image_id"]
+ if annotation["image_id"] in image_ids:
+ image_id_to_annotations[image_id].append(annotation)
+
+ image_name_to_prediction = get_pose_predictions(
+ loader=loader,
+ images=[image["file_name"] for image in project_gt["images"]],
+ bboxes=image_name_to_bbox,
+ superanimal_name=superanimal_name,
+ model_snapshot_path=model_snapshot_path,
+ detector_snapshot_path=detector_snapshot_path,
+ max_individuals=max_individuals,
+ device=device,
+ )
+
+ def xywh2xyxy(bbox):
+ temp_bbox = np.copy(bbox)
+ temp_bbox[2:] = temp_bbox[:2] + temp_bbox[2:]
+ return temp_bbox
+
+ def optimal_match(gts_list, preds_list):
+ arranged_preds_list = []
+ num_gts = len(gts_list)
+ num_preds = len(preds_list)
+ cost_matrix = np.zeros((num_gts, num_preds))
+
+ for i in range(num_gts):
+ for j in range(num_preds):
+ cost_matrix[i, j] = distance.euclidean(
+ gts_list[i][..., :2].flatten(), preds_list[j][..., :2].flatten()
+ )
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+ return col_ind
+
+ num_bodyparts = len(project_gt["categories"][0]["keypoints"])
+ for image_name, gts in image_name_to_gt.items():
+ bbox_gts = [np.array(gt["bbox"]) for gt in gts]
+ bbox_gts = [xywh2xyxy(e) for e in bbox_gts]
+ prediction = image_name_to_prediction[image_name]
+ bbox_preds = [xywh2xyxy(pred) for pred in prediction["bboxes"]]
+ optimal_pred_indices = optimal_match(bbox_gts, bbox_preds)
+
+ for idx in range(len(bbox_gts)):
+ if idx == len(optimal_pred_indices):
+ break
+
+ optimal_index = optimal_pred_indices[idx]
+ matched_gt = np.array(gts[idx]["keypoints"])
+ matched_pred = prediction["bodyparts"][optimal_index]
+ bbox_gt = bbox_gts[idx]
+ bbox_pred = bbox_preds[idx]
+
+ # maybe check iou of two bbox
+ iou = calculate_iou(bbox_gt, bbox_pred)
+ if iou < 0.7:
+ matched_gt = np.ones_like(matched_gt) * -1
+ gts[idx]["keypoints"] = list(matched_gt.flatten())
+ else:
+ matched_gt = matched_gt.reshape(num_bodyparts, -1)
+ matched_pred = matched_pred.reshape(num_bodyparts, -1)
+ mask = matched_gt == -1
+ matched_gt[mask] = matched_pred[mask]
+ # after the mixing, we don't care about confidence anymore
+
+ for kpt_idx in range(len(matched_gt)):
+ if 0 < matched_gt[kpt_idx][2] < pose_threshold:
+ matched_gt[kpt_idx][2] = -1
+ elif matched_gt[kpt_idx][2] > 0:
+ matched_gt[kpt_idx][2] = 2
+
+ gts[idx]["keypoints"] = list(matched_gt.flatten())
+
+ # memory replay path
+ memory_replay_train_file_path = os.path.join(
+ source_dataset_folder, "annotations", "memory_replay_train.json"
+ )
+
+ # parse the GT to put the image paths back into OS-specific format
+ for image in project_gt["images"]:
+ image_rel_path = image["file_name"].split("/")
+ image["file_name"] = str(project_root.resolve() / Path(*image_rel_path))
+
+ with open(memory_replay_train_file_path, "w") as f:
+ json.dump(project_gt, f, indent=4)
+
+
+def prepare_memory_replay(
+ config: str | Path,
+ loader: DLCLoader,
+ superanimal_name: str,
+ model_snapshot_path: str | Path,
+ detector_snapshot_path: str | Path,
+ device: str,
+ max_individuals: int = 3,
+ train_file: str = "train.json",
+ pose_threshold: float = 0.1,
+) -> None:
+ """Prepares a shuffle to be trained with memory replay.
+
+ To be trained using memory replay, predictions must be made on all images in the
+ dataset using the SuperAnimal model. Predictions for bodyparts that aren't labeled
+ in the DeepLabCut project are then used as pseudo-labels during training.
+
+ This method will create a COCO-format dataset in the same folder as the
+ ``pytorch_config.yaml`` (the model folder).
+
+ Args:
+ config: Path to the DeepLabCut project configuration file.
+ loader: The loader used to load the training/test data on which a model will
+ be fine-tuned with memory replay.
+ superanimal_name: The name of the SuperAnimal model that is being fine-tuned.
+ model_snapshot_path: Path to the SuperAnimal pose snapshot to fine-tune.
+ detector_snapshot_path: Path to the SuperAnimal detector snapshot to fine-tune.
+ device: Device to use to run inference using the SuperAnimal model.
+ max_individuals: Maximum number of animals that can be present in a frame.
+ train_file: Name of the file containing train annotations (e.g. `train.json`).
+ pose_threshold: The minimum score for a prediction to be used as a pseudo-label.
+ """
+ cfg = af.read_config(config)
+ super_animal_cfg = af.read_plainconfig(
+ get_super_animal_project_config_path(super_animal=superanimal_name)
+ )
+
+ if "individuals" in cfg:
+ temp_dataset = MaDLCPoseDataset(
+ str(loader.project_path), "temp_dataset", shuffle=loader.shuffle
+ )
+ else:
+ temp_dataset = SingleDLCPoseDataset(
+ str(loader.project_path), "temp_dataset", shuffle=loader.shuffle
+ )
+
+ memory_replay_folder = loader.model_folder / "memory_replay"
+ temp_dataset.materialize(
+ memory_replay_folder,
+ framework="coco",
+ append_image_id=False,
+ no_image_copy=True, # use the images in the labeled-data folder
+ )
+
+ weight_init_cfg = loader.model_cfg["train_settings"].get("weight_init")
+ if weight_init_cfg is None:
+ raise ValueError(
+ "You can only train models with memory replay when you are fine-tuning a "
+ "SuperAnimal model. Please look at the documentation to see how to create "
+ "a training dataset to fine-tune one of the SuperAnimal models."
+ )
+
+ weight_init = WeightInitialization.from_dict(weight_init_cfg)
+ if not weight_init.with_decoder:
+ raise ValueError(
+ "You can only train models with memory replay when you are fine-tuning a "
+ "SuperAnimal model. Please look at the documentation to see how to create "
+ "a training dataset to fine-tune one of the SuperAnimal models. Ensure "
+ "that a conversion table is specified for your project and that you select"
+ "``with_decoder=True`` for your ``WeightInitialization``."
+ )
+
+ dataset = COCOPoseDataset(memory_replay_folder, "memory_replay_dataset")
+
+ # here we project the original DLC projects to superanimal space and save them into
+ # a coco project format
+ bodyparts = af.get_bodyparts(cfg)
+ sa_bodyparts = af.get_bodyparts(super_animal_cfg)
+ conversion_table = {}
+ for idx, bpt in enumerate(bodyparts):
+ conversion_table[bpt] = sa_bodyparts[weight_init.conversion_array[idx]]
+
+ dataset.project_with_conversion_table(
+ table_path=None,
+ table_dict=dict(
+ master_keypoints=sa_bodyparts,
+ conversion_table=conversion_table,
+ ),
+ )
+
+ dataset.materialize(
+ memory_replay_folder, framework="coco", deepcopy=False, no_image_copy=True,
+ )
+
+ # then in this function, we do pseudo label to match prediction and gts to create
+ # memory-replay dataset that will be named memory_replay_train.json
+ prepare_memory_replay_dataset(
+ loader,
+ memory_replay_folder,
+ superanimal_name,
+ model_snapshot_path,
+ detector_snapshot_path,
+ max_individuals=max_individuals,
+ device=device,
+ train_file=train_file,
+ pose_threshold=pose_threshold,
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/train_from_coco.py b/deeplabcut/pose_estimation_pytorch/modelzoo/train_from_coco.py
new file mode 100644
index 0000000000..7625de89e9
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/train_from_coco.py
@@ -0,0 +1,93 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""File to train a model on a COCO dataset"""
+
+from __future__ import annotations
+
+import copy
+from pathlib import Path
+
+from deeplabcut.pose_estimation_pytorch import COCOLoader, utils
+from deeplabcut.pose_estimation_pytorch.apis.training import train
+from deeplabcut.pose_estimation_pytorch.runners.logger import setup_file_logging
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+def adaptation_train(
+ project_root: str | Path,
+ model_folder: str | Path,
+ train_file: str,
+ test_file: str,
+ model_config_path: str | Path,
+ device: str | None,
+ epochs: int | None,
+ save_epochs: int | None,
+ detector_epochs: int | None,
+ detector_save_epochs: int | None,
+ snapshot_path: str | None,
+ detector_path: str | None,
+ batch_size: int = 8,
+ detector_batch_size: int = 8,
+ eval_interval: int | None = None,
+):
+ setup_file_logging(Path(model_folder) / "log.txt")
+ loader = COCOLoader(
+ project_root=project_root,
+ model_config_path=model_config_path,
+ train_json_filename=train_file,
+ test_json_filename=test_file,
+ )
+
+ utils.fix_seeds(loader.model_cfg["train_settings"]["seed"])
+
+ updates = {
+ "detector.model.freeze_bn_stats": True,
+ "detector.runner.snapshots.max_snapshots": 5,
+ "detector.runner.snapshots.save_epochs": detector_save_epochs or 1,
+ "detector.train_settings.batch_size": detector_batch_size,
+ "detector.train_settings.epochs": detector_epochs or 4,
+ "model.backbone.freeze_bn_stats": True,
+ "runner.snapshots.max_snapshots": 5,
+ "runner.snapshots.save_epochs": save_epochs or 1,
+ "train_settings.batch_size": batch_size,
+ "train_settings.epochs": epochs or 4,
+ }
+
+ if eval_interval is not None:
+ updates["runner.eval_interval"] = eval_interval
+
+ loader.update_model_cfg(updates)
+
+ pose_task = Task(loader.model_cfg["method"])
+ if pose_task == Task.TOP_DOWN:
+ logger_config = None
+ if loader.model_cfg.get("logger"):
+ logger_config = copy.deepcopy(loader.model_cfg["logger"])
+ logger_config["run_name"] += "-detector"
+
+ if loader.model_cfg["detector"]["train_settings"]["epochs"] > 0:
+ train(
+ loader=loader,
+ run_config=loader.model_cfg["detector"],
+ task=Task.DETECT,
+ device=device,
+ logger_config=logger_config,
+ snapshot_path=detector_path,
+ )
+
+ train(
+ loader=loader,
+ run_config=loader.model_cfg,
+ task=pose_task,
+ device=device,
+ logger_config=loader.model_cfg.get("logger"),
+ snapshot_path=snapshot_path,
+ )
diff --git a/deeplabcut/pose_estimation_pytorch/modelzoo/utils.py b/deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
new file mode 100644
index 0000000000..9138c22b0a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/modelzoo/utils.py
@@ -0,0 +1,203 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import inspect
+import subprocess
+import warnings
+from pathlib import Path
+
+import torch
+from dlclibrary import download_huggingface_model
+
+import deeplabcut.pose_estimation_pytorch.config.utils as config_utils
+from deeplabcut.core.config import read_config_as_dict
+from deeplabcut.pose_estimation_pytorch.config.make_pose_config import add_metadata
+from deeplabcut.utils import auxiliaryfunctions
+
+
+def get_model_configs_folder_path() -> Path:
+ """Returns: the folder containing the SuperAnimal model configuration files"""
+ return Path(auxiliaryfunctions.get_deeplabcut_path()) / "modelzoo" / "model_configs"
+
+
+def get_project_configs_folder_path() -> Path:
+ """Returns: the folder containing the SuperAnimal project configuration files"""
+ return (
+ Path(auxiliaryfunctions.get_deeplabcut_path()) / "modelzoo" / "project_configs"
+ )
+
+
+def get_snapshot_folder_path() -> Path:
+ """Returns: the path to the folder containing the SuperAnimal model snapshots"""
+ return Path(auxiliaryfunctions.get_deeplabcut_path()) / "modelzoo" / "checkpoints"
+
+
+def get_super_animal_model_config_path(model_name: str) -> Path:
+ """Gets the path to the configuration file for a SuperAnimal model.
+
+ Args:
+ model_name: The name of the model for which to get the path.
+
+ Returns:
+ The path to the config file for a SuperAnimal model.
+ """
+ return get_model_configs_folder_path() / f"{model_name}.yaml"
+
+
+def get_super_animal_project_config_path(super_animal: str) -> Path:
+ """Gets the path to a SuperAnimal project configuration file.
+
+ Args:
+ super_animal: The name of the SuperAnimal for which to get the config path.
+
+ Returns:
+ The path to the config file for a SuperAnimal project.
+ """
+ return get_project_configs_folder_path() / f"{super_animal}.yaml"
+
+
+def get_super_animal_snapshot_path(
+ dataset: str,
+ model_name: str,
+ download: bool = True,
+) -> Path:
+ """Gets the path to the snapshot containing SuperAnimal model weights.
+
+ Args:
+ dataset: The name of the SuperAnimal dataset.
+ model_name: The name of the model.
+ download: Whether to download the weights if they aren't already there.
+
+ Returns:
+ The path to the weights for a SuperAnimal model.
+ """
+ model_path = get_snapshot_folder_path() / f"{dataset}_{model_name}.pt"
+ if download and not model_path.exists():
+ download_super_animal_snapshot(dataset, model_name)
+
+ return model_path
+
+
+def load_super_animal_config(
+ super_animal: str,
+ model_name: str,
+ detector_name: str | None = None,
+ max_individuals: int = 30,
+ device: str | None = None,
+) -> dict:
+ """Loads the model configuration file for a model, detector and SuperAnimal
+
+ Args:
+ super_animal: The name of the SuperAnimal for which to create the model config.
+ model_name: The name of the model for which to create the model config.
+ detector_name: The name of the detector for which to create the model config.
+ max_individuals: The maximum number of detections to make in an image
+ device: The device to use to train/run inference on the model
+
+ Returns:
+ The model configuration for a SuperAnimal-pretrained model.
+ """
+ project_cfg_path = get_super_animal_project_config_path(super_animal=super_animal)
+ project_config = read_config_as_dict(project_cfg_path)
+
+ model_cfg_path = get_super_animal_model_config_path(model_name=model_name)
+ model_config = read_config_as_dict(model_cfg_path)
+ model_config = add_metadata(project_config, model_config, model_cfg_path)
+ model_config = update_config(model_config, max_individuals, device)
+
+ if detector_name is None:
+ model_config["method"] = "BU"
+ else:
+ detector_cfg_path = get_super_animal_model_config_path(model_name=detector_name)
+ detector_cfg = read_config_as_dict(detector_cfg_path)
+ model_config["method"] = "TD"
+ model_config["detector"] = detector_cfg
+ return model_config
+
+
+def download_super_animal_snapshot(dataset: str, model_name: str) -> Path:
+ """Downloads a SuperAnimal snapshot
+
+ Args:
+ dataset: The name of the SuperAnimal dataset for which to download a snapshot.
+ model_name: The name of the model for which to download a snapshot.
+
+ Returns:
+ The path to the downloaded snapshot.
+
+ Raises:
+ RuntimeError if the model fails to download.
+ """
+ snapshot_dir = get_snapshot_folder_path()
+ model_name = f"{dataset}_{model_name}"
+ model_path = snapshot_dir / f"{model_name}.pt"
+
+ download_huggingface_model(model_name, target_dir=str(snapshot_dir))
+ if not model_path.exists():
+ raise RuntimeError(f"Failed to download {model_name} to {model_path}")
+
+ return snapshot_dir / f"{model_name}.pt"
+
+
+def get_gpu_memory_map():
+ """Get the current gpu usage."""
+ result = subprocess.check_output(
+ ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,nounits,noheader"],
+ encoding="utf-8",
+ )
+ gpu_memory = [int(x) for x in result.strip().split("\n")]
+ gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
+
+ return gpu_memory_map
+
+
+def select_device():
+ if torch.cuda.is_available():
+ return torch.device(f"cuda:0")
+ else:
+ return torch.device("cpu")
+
+
+def raise_warning_if_called_directly():
+ current_frame = inspect.currentframe()
+ caller_frame = inspect.getouterframes(current_frame, 2)
+ caller_name = caller_frame[1].filename
+
+ if not "pose_estimation_" in caller_name:
+ warnings.warn(
+ f"{caller_name} is intended for internal use only and should not be called directly.",
+ UserWarning,
+ )
+
+
+def update_config(config: dict, max_individuals: int, device: str):
+ """Loads the model configuration file for a model, detector and SuperAnimal
+
+ Args:
+ config: The default model configuration file.
+ max_individuals: The maximum number of detections to make in an image
+ device: The device to use to train/run inference on the model
+
+ Returns:
+ The model configuration for a SuperAnimal-pretrained model.
+ """
+ config = config_utils.replace_default_values(
+ config,
+ num_bodyparts=len(config["metadata"]["bodyparts"]),
+ num_individuals=max_individuals,
+ backbone_output_channels=config["model"]["backbone_output_channels"],
+ )
+ config["metadata"]["individuals"] = [f"animal{i}" for i in range(max_individuals)]
+
+ config["device"] = device
+ if "detector" in config:
+ config["detector"]["device"] = device
+
+ return config
diff --git a/deeplabcut/pose_estimation_pytorch/post_processing/__init__.py b/deeplabcut/pose_estimation_pytorch/post_processing/__init__.py
new file mode 100644
index 0000000000..25f07ebbf1
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/post_processing/__init__.py
@@ -0,0 +1,14 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.post_processing.match_predictions_to_gt import (
+ oks_match_prediction_to_gt,
+ rmse_match_prediction_to_gt,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/post_processing/identity.py b/deeplabcut/pose_estimation_pytorch/post_processing/identity.py
new file mode 100644
index 0000000000..9f81ab4619
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/post_processing/identity.py
@@ -0,0 +1,46 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Functions to assign identity to predictions from an identity head"""
+from __future__ import annotations
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+
+def assign_identity(
+ predictions: np.ndarray, identity_scores: np.ndarray
+) -> np.ndarray:
+ """
+ Args:
+ predictions: Pose predictions for an image, with shape (num_individuals,
+ num_bodyparts, 3)
+ identity_scores: Identity predictions for keypoints in an image, of shape
+ (num_individuals, num_bodyparts, num_individuals).
+
+ Returns:
+ The ordering to use to match predictions to identities.
+ """
+ if not len(predictions) == len(identity_scores):
+ raise ValueError(
+ "There are not the same number of predictions as identity scores"
+ f" ({len(predictions)} != {len(identity_scores)}"
+ )
+
+ # average of ID scores, weighted by keypoint confidence
+ pose_conf = predictions[:, :, 2:3]
+ cost_matrix = np.mean(pose_conf * identity_scores, axis=1)
+
+ row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True)
+ new_order = np.zeros_like(row_ind)
+ for old_pos, new_pos in zip(row_ind, col_ind):
+ new_order[new_pos] = old_pos
+
+ return new_order
diff --git a/deeplabcut/pose_estimation_pytorch/post_processing/match_predictions_to_gt.py b/deeplabcut/pose_estimation_pytorch/post_processing/match_predictions_to_gt.py
new file mode 100644
index 0000000000..7113303e3c
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/post_processing/match_predictions_to_gt.py
@@ -0,0 +1,190 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+from deeplabcut.core.inferenceutils import (
+ calc_object_keypoint_similarity,
+)
+
+
+def rmse_match_prediction_to_gt(
+ pred_kpts: np.ndarray, gt_kpts: np.ndarray
+) -> np.ndarray:
+ """
+ Hungarian algorithm predicted individuals to ground truth ones, using root mean
+ squared error (rmse). The function provides a way to match predicted individuals to
+ ground truth individuals based on the rmse distance between their corresponding
+ keypoints. This algorithm is used to find the optimal matching, taking into account
+ the potential missing animal.
+
+ Raises:
+ ValueError: if `gt_kpts.shape != pred_kpts.shape`
+
+ Args:
+ pred_kpts: shape (num_predictions, num_keypoints, 3), ground truth keypoints for
+ an image, where the 3 values are (x,y,score) for each keypoint
+ gt_kpts: shape (num_individuals, num_keypoints, 3), ground truth keypoints for
+ an image, where the 3 values are (x,y,visibility) for each keypoint
+
+ Returns:
+ col_ind: array of the individuals indices for prediction
+ """
+ num_pred, num_keypoints, _ = pred_kpts.shape
+ num_idv, num_keypoints_gt, _ = gt_kpts.shape
+ if num_keypoints + 1 == num_keypoints_gt:
+ gt_kpts = gt_kpts[:, :-1, :].copy()
+ elif num_keypoints == num_keypoints_gt:
+ gt_kpts = gt_kpts.copy()
+ else:
+ raise ValueError("Shape mismatch between ground truth and predictions")
+
+ valid_gt = np.any(gt_kpts[..., 2] > 0, axis=1)
+ valid_gt_indices = np.nonzero(valid_gt)[0]
+ if len(valid_gt_indices) == 0:
+ return np.arange(num_idv)
+
+ valid_pred = np.any(pred_kpts[..., 2] > 0, axis=1)
+ valid_pred_indices = np.nonzero(valid_pred)[0]
+ if len(valid_pred_indices) == 0:
+ return np.arange(num_idv)
+
+ distance_matrix = np.full((len(valid_gt_indices), len(valid_pred_indices)), np.nan)
+ for i, gt_idx in enumerate(valid_gt_indices):
+ gt_idv = gt_kpts[gt_idx]
+ mask = gt_idv[:, 2] > 0
+ for j, pred_idx in enumerate(valid_pred_indices):
+ pred_idv = pred_kpts[pred_idx]
+ d = (gt_idv[mask, :2] - pred_idv[mask, :2]) ** 2
+ if np.any(~np.isnan(d)):
+ distance_matrix[i, j] = np.nanmean(d)
+
+ if np.all(np.isnan(distance_matrix)):
+ return np.arange(num_idv)
+
+ # np.inf and np.nan in linear_sum_assigment raises error; so when a prediction
+ # cannot be assigned to a ground truth (e.g. with PAFs, where predicted bodyparts
+ # can be NaN) set the distance to a distance greater than the maximum distance
+ max_dist = np.nanmax(distance_matrix)
+ distance_matrix = np.nan_to_num(distance_matrix, nan=100 * max_dist)
+ _, col_ind = linear_sum_assignment(distance_matrix) # len == len(valid_gt_indices)
+
+ gt_idx_to_pred_idx = {
+ valid_gt_indices[valid_gt_index]: valid_pred_indices[valid_pred_index]
+ for valid_gt_index, valid_pred_index in enumerate(col_ind)
+ }
+ matched_pred = {valid_pred_indices[i] for i in col_ind}
+ unmatched_pred = [i for i in range(num_idv) if i not in matched_pred]
+ next_unmatched = 0
+ col_ind = []
+ for gt_index in range(num_idv):
+ if gt_index in gt_idx_to_pred_idx:
+ col_ind.append(gt_idx_to_pred_idx[gt_index])
+ else:
+ col_ind.append(unmatched_pred[next_unmatched])
+ next_unmatched += 1
+
+ return np.array(col_ind)
+
+
+def oks_match_prediction_to_gt(
+ pred_kpts: np.array, gt_kpts: np.array, individual_names: list
+) -> np.array:
+ """Summary:
+ Hungarian algorithm predicted individuals to ground truth ones, using object keypoint similarity (oks).
+ Oks measures the accuracy of predicted keypoints compared to ground truth keypoints.
+ More information about oks can be found in cocodataset (https://cocodataset.org/#keypoints-eval).
+
+ Args:
+ pred_kpts: Predicted keypoints for each animal. The shape of the array is (num_animals, num_keypoints, 3):
+ num_animals: Number of animals.
+ num_keypoints: Number of keypoints.
+ 3: (x, y, score) coordinates of each keypoint.
+ gt_kpts: Ground truth keypoints for each animal. The shape of the array is (num_animals, num_keypoints(+1 if with center), 2):
+ num_animals: Number of animals.
+ num_keypoints: Number of keypoints.
+ individual_names: names of individuals
+
+ Returns:
+ col_ind: Array of the individual indexes for prediction.
+
+ Examples:
+ input:
+ pred_kpts = np.array(...)
+ gt_kpts = np.array(...)
+ individual_names = [...]
+ output:
+ col_ind = np.array([...])
+ """
+
+ num_animals, num_keypoints, _ = pred_kpts.shape
+ if num_keypoints + 1 == gt_kpts.shape[1]:
+ gt_kpts_without_ctr = gt_kpts[:, :-1, :].copy()
+ elif num_keypoints == gt_kpts.shape[1]:
+ gt_kpts_without_ctr = gt_kpts.copy()
+ else:
+ raise ValueError("Shape mismatch between ground truth and predictions")
+
+ # Computation of the number of annotated animals in the ground truth
+ num_animals_gt = num_animals
+ for animal_index in range(num_animals):
+ if (gt_kpts_without_ctr[animal_index] < 0).all():
+ num_animals_gt -= 1
+
+ oks_matrix = np.zeros((num_animals_gt, num_animals))
+ gt_kpts_without_ctr[
+ gt_kpts_without_ctr < 0
+ ] = np.nan # non visible keypoints should be nan to use calc_oks
+ idx_gt = -1
+ for g in range(num_animals):
+ if np.isnan(gt_kpts_without_ctr[g]).all():
+ continue
+ else:
+ idx_gt += 1
+ for p in range(num_animals):
+ oks_matrix[idx_gt, p] = calc_object_keypoint_similarity(
+ pred_kpts[p, :, :2],
+ gt_kpts_without_ctr[g],
+ 0.1,
+ margin=0,
+ symmetric_kpts=None, # TODO take into account symmetric keypoints
+ )
+
+ row_ind, col_ind = linear_sum_assignment(oks_matrix, maximize=True)
+ # if animals are missing in the frame, the predictions corresponding to nothing are not shuffled
+ col_ind = extend_col_ind(col_ind, num_animals)
+
+ return col_ind
+
+
+def extend_col_ind(col_ind: np.array, num_animals: int) -> np.array:
+ """Summary:
+ Extends the column indices of a 1D array, col_ind, by adding any missing column indices from 0 to num_animals-1.
+
+ Args:
+ col_ind: 1D array of column indices
+ num_animals: total number of animals
+
+ Returns:
+ extended_array: extended 1D array of column indices
+
+ Examples:
+ input:
+ col_ind =
+ num_animals = 5
+ output:
+ extended_array =
+ """
+ existing_cols = set(col_ind) # Convert the array to a set for faster lookup
+ missing_cols = [num for num in range(num_animals) if num not in existing_cols]
+ extended_array = np.concatenate((col_ind, missing_cols)).astype(int)
+ return extended_array
diff --git a/deeplabcut/pose_estimation_pytorch/registry.py b/deeplabcut/pose_estimation_pytorch/registry.py
new file mode 100644
index 0000000000..0cc91ac0be
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/registry.py
@@ -0,0 +1,333 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import inspect
+from functools import partial
+from typing import Any, Dict, Optional
+
+
+def build_from_cfg(
+ cfg: Dict, registry: "Registry", default_args: Optional[Dict] = None
+) -> Any:
+ """Builds a module from the configuration dictionary when it represents a class configuration,
+ or call a function from the configuration dictionary when it represents a function configuration.
+
+ Args:
+ cfg: Configuration dictionary. It should at least contain the key "type".
+ registry: The registry to search the type from.
+ default_args: Default initialization arguments.
+ Defaults to None.
+
+ Returns:
+ Any: The constructed object.
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry, build_from_cfg
+ >>> class Model:
+ >>> def __init__(self, param):
+ >>> self.param = param
+ >>> cfg = {"type": "Model", "param": 10}
+ >>> registry = Registry("models")
+ >>> registry.register_module(Model)
+ >>> obj = build_from_cfg(cfg, registry)
+ >>> assert isinstance(obj, Model)
+ >>> assert obj.param == 10
+ """
+
+ args = cfg.copy()
+
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ obj_type = args.pop("type")
+ if isinstance(obj_type, str):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(f"{obj_type} is not in the {registry.name} registry")
+ elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(f"type must be a str or valid type, but got {type(obj_type)}")
+ try:
+ return obj_cls(**args)
+ except Exception as e:
+ # Normal TypeError does not print class name.
+ raise type(e)(f"{obj_cls.__name__}: {e}")
+
+
+class Registry:
+ """A registry to map strings to classes or functions.
+ Registered objects could be built from the registry. Meanwhile, registered
+ functions could be called from the registry.
+
+ Args:
+ name: Registry name.
+ build_func: Builds function to construct an instance from
+ the Registry. If neither ``parent`` nor
+ ``build_func`` is specified, the ``build_from_cfg``
+ function is used. If ``parent`` is specified and
+ ``build_func`` is not given, ``build_func`` will be
+ inherited from ``parent``. Default: None.
+ parent: Parent registry. The class registered in
+ children's registry could be built from the parent.
+ Default: None.
+ scope: The scope of the registry. It is the key to search
+ for children's registry. If not specified, scope will be the
+ name of the package where the class is defined, e.g. mmdet, mmcls, mmseg.
+ Default: None.
+
+ Attributes:
+ name: Registry name.
+ module_dict: The dictionary containing registered modules.
+ children: The dictionary containing children registries.
+ scope: The scope of the registry.
+ """
+
+ def __init__(self, name, build_func=None, parent=None, scope=None):
+ self._name = name
+ self._module_dict = dict()
+ self._children = dict()
+ self._scope = "."
+
+ if build_func is None:
+ if parent is not None:
+ self.build_func = parent.build_func
+ else:
+ self.build_func = build_from_cfg
+ else:
+ self.build_func = build_func
+ if parent is not None:
+ assert isinstance(parent, Registry)
+ parent._add_children(self)
+ self.parent = parent
+ else:
+ self.parent = None
+
+ def __len__(self):
+ return len(self._module_dict)
+
+ def __contains__(self, key):
+ return self.get(key) is not None
+
+ def __repr__(self):
+ format_str = (
+ self.__class__.__name__ + f"(name={self._name}, "
+ f"items={self._module_dict})"
+ )
+ return format_str
+
+ @staticmethod
+ def split_scope_key(key):
+ """Split scope and key.
+ The first scope will be split from key.
+ Examples:
+ >>> Registry.split_scope_key('mmdet.ResNet')
+ 'mmdet', 'ResNet'
+ >>> Registry.split_scope_key('ResNet')
+ None, 'ResNet'
+ Return:
+ tuple[str | None, str]: The former element is the first scope of
+ the key, which can be ``None``. The latter is the remaining key.
+ """
+ split_index = key.find(".")
+ if split_index != -1:
+ return key[:split_index], key[split_index + 1 :]
+ else:
+ return None, key
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def scope(self):
+ return self._scope
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ @property
+ def children(self):
+ return self._children
+
+ def get(self, key):
+ """Get the registry record.
+
+ Args:
+ key: The class name in string format.
+
+ Returns:
+ class: The corresponding class.
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry
+ >>> registry = Registry("models")
+ >>> class Model:
+ >>> pass
+ >>> registry.register_module(Model, "Model")
+ >>> assert registry.get("Model") == Model
+ """
+ scope, real_key = self.split_scope_key(key)
+ if scope is None or scope == self._scope:
+ # get from self
+ if real_key in self._module_dict:
+ return self._module_dict[real_key]
+ else:
+ # get from self._children
+ if scope in self._children:
+ return self._children[scope].get(real_key)
+ else:
+ # goto root
+ parent = self.parent
+ while parent.parent is not None:
+ parent = parent.parent
+ return parent.get(key)
+
+ def build(self, *args, **kwargs):
+ """Builds an instance from the registry.
+
+ Args:
+ *args: Arguments passed to the build function.
+ **kwargs: Keyword arguments passed to the build function.
+
+ Returns:
+ Any: The constructed object.
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry, build_from_cfg
+ >>> class Model:
+ >>> def __init__(self, param):
+ >>> self.param = param
+ >>> cfg = {"type": "Model", "param": 10}
+ >>> registry = Registry("models")
+ >>> registry.register_module(Model)
+ >>> obj = registry.build(cfg, param=20)
+ >>> assert isinstance(obj, Model)
+ >>> assert obj.param == 20
+ """
+ return self.build_func(*args, **kwargs, registry=self)
+
+ def _add_children(self, registry):
+ """Add children for a registry.
+
+ Args:
+ registry: The registry to be added as children based on its scope.
+
+ Returns:
+ None
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry
+ >>> models = Registry('models')
+ >>> mmdet_models = Registry('models', parent=models)
+ >>> class Model:
+ >>> pass
+ >>> mmdet_models.register_module(Model)
+ >>> obj = models.build(dict(type='mmdet.Model'))
+ >>> assert isinstance(obj, Model)
+ """
+ assert isinstance(registry, Registry)
+ assert registry.scope is not None
+ assert (
+ registry.scope not in self.children
+ ), f"scope {registry.scope} exists in {self.name} registry"
+ self.children[registry.scope] = registry
+
+ def _register_module(self, module, module_name=None, force=False):
+ """Register a module.
+
+ Args:
+ module: Module class or function to be registered.
+ module_name: The module name(s) to be registered.
+ If not specified, the class name will be used.
+ force: Whether to override an existing class with the same name.
+ Default: False.
+
+ Returns:
+ None
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry
+ >>> registry = Registry("models")
+ >>> class Model:
+ >>> pass
+ >>> registry._register_module(Model, "Model")
+ >>> assert registry.get("Model") == Model
+ """
+ if not inspect.isclass(module) and not inspect.isfunction(module):
+ raise TypeError(
+ "module must be a class or a function, " f"but got {type(module)}"
+ )
+
+ if module_name is None:
+ module_name = module.__name__
+ if isinstance(module_name, str):
+ module_name = [module_name]
+ for name in module_name:
+ if not force and name in self._module_dict:
+ raise KeyError(f"{name} is already registered " f"in {self.name}")
+ self._module_dict[name] = module
+
+ def deprecated_register_module(self, cls=None, force=False):
+ """Decorator to register a class in the registry.
+
+ Args:
+ cls: The class to be registered.
+ force: Whether to override an existing class with the same name.
+ Default: False.
+
+ Returns:
+ type: The input class.
+
+ Example:
+ >>> from deeplabcut.pose_estimation_pytorch.registry import Registry
+ >>> registry = Registry("models")
+ >>> @registry.deprecated_register_module()
+ >>> class Model:
+ >>> pass
+ >>> assert registry.get("Model") == Model
+ """
+ if cls is None:
+ return partial(self.deprecated_register_module, force=force)
+ self._register_module(cls, force=force)
+ return cls
+
+ def register_module(self, name=None, force=False, module=None):
+ """Register a module.
+ A record will be added to `self._module_dict`, whose key is the class
+ name or the specified name, and value is the class itself.
+ It can be used as a decorator or a normal function.
+ Args:
+ name: The module name to be registered. If not
+ specified, the class name will be used.
+ force: Whether to override an existing class with
+ the same name. Default: False.
+ module: Module class or function to be registered.
+ """
+ if not isinstance(force, bool):
+ raise TypeError(f"force must be a boolean, but got {type(force)}")
+ # NOTE: This is a walkaround to be compatible with the old api,
+ # while it may introduce unexpected bugs.
+ if isinstance(name, type):
+ return self.deprecated_register_module(name, force=force)
+
+ # use it as a normal method: x.register_module(module=SomeClass)
+ if module is not None:
+ self._register_module(module=module, module_name=name, force=force)
+ return module
+
+ # use it as a decorator: @x.register_module()
+ def _register(module):
+ self._register_module(module=module, module_name=name, force=force)
+ return module
+
+ return
diff --git a/deeplabcut/pose_estimation_pytorch/runners/__init__.py b/deeplabcut/pose_estimation_pytorch/runners/__init__.py
new file mode 100644
index 0000000000..e5a5922df6
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/__init__.py
@@ -0,0 +1,33 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+from deeplabcut.pose_estimation_pytorch.runners.base import (
+ attempt_snapshot_load,
+ get_load_weights_only,
+ fix_snapshot_metadata,
+ Runner,
+ set_load_weights_only,
+)
+from deeplabcut.pose_estimation_pytorch.runners.dynamic_cropping import DynamicCropper
+from deeplabcut.pose_estimation_pytorch.runners.inference import (
+ build_inference_runner,
+ DetectorInferenceRunner,
+ InferenceRunner,
+ PoseInferenceRunner,
+)
+from deeplabcut.pose_estimation_pytorch.runners.logger import LOGGER
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import TorchSnapshotManager
+from deeplabcut.pose_estimation_pytorch.runners.train import (
+ build_training_runner,
+ DetectorTrainingRunner,
+ PoseTrainingRunner,
+ TrainingRunner,
+)
diff --git a/deeplabcut/pose_estimation_pytorch/runners/base.py b/deeplabcut/pose_estimation_pytorch/runners/base.py
new file mode 100644
index 0000000000..ee8dbd6d1a
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/base.py
@@ -0,0 +1,226 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import logging
+import os
+import pickle
+from abc import ABC
+from pathlib import Path
+from typing import Generic, TypeVar
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+ModelType = TypeVar("ModelType", bound=nn.Module)
+
+_load_weights_only: bool = (
+ os.getenv("TORCH_LOAD_WEIGHTS_ONLY", "true").lower() in ("true", "1")
+)
+
+
+def get_load_weights_only() -> bool:
+ """Gets the default value to use when loading snapshots with `torch.load(...)`.
+
+ Returns:
+ The default `weights_only` value when loading snapshots using `torch.load(...)`.
+ """
+ global _load_weights_only
+ return _load_weights_only
+
+
+def set_load_weights_only(value: bool) -> None:
+ """Sets the default value to use when loading snapshots with `torch.load(...)`.
+
+ Args:
+ value: The default `weights_only` value to use when loading snapshots using
+ `torch.load(...)`.
+ """
+ global _load_weights_only
+ _load_weights_only = value
+
+
+class Runner(ABC, Generic[ModelType]):
+ """Runner base class
+
+ A runner takes a model and runs actions on it, such as training or inference
+ """
+
+ def __init__(
+ self,
+ model: ModelType,
+ device: str = "cpu",
+ gpus: list[int] | None = None,
+ snapshot_path: str | Path | None = None,
+ ):
+ """
+ Args:
+ model: the model to run
+ device: the device to use (e.g. {'cpu', 'cuda:0', 'mps'})
+ gpus: the list of GPU indices to use for multi-GPU training
+ snapshot_path: the path of a snapshot from which to load model weights
+ """
+ if gpus is None:
+ gpus = []
+
+ if len(gpus) == 1:
+ if device != "cuda":
+ raise ValueError(
+ "When specifying a GPU index to train on, the device must be set "
+ f"to 'cuda'. Found {device}"
+ )
+ device = f"cuda:{gpus[0]}"
+
+ self.model = model
+ self.device = device
+ self.snapshot_path = snapshot_path
+ self._gpus = gpus
+ self._data_parallel = len(gpus) > 1
+
+ @staticmethod
+ def load_snapshot(
+ snapshot_path: str | Path,
+ device: str,
+ model: ModelType,
+ weights_only: bool | None = None,
+ ) -> dict:
+ """Loads the state dict for a model from a file
+
+ This method loads a file containing a DeepLabCut PyTorch model snapshot onto
+ a given device, and sets the model weights using the state_dict.
+
+ Args:
+ snapshot_path: The path containing the model weights to load
+ device: The device on which the model should be loaded
+ model: The model for which the weights are loaded
+ weights_only: Value for torch.load() `weights_only` parameter.
+ If False, the python pickle module is used implicitly, which is known to
+ be insecure. Only set to False if you're loading data that you trust
+ (e.g. snapshots that you created yourself). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+
+ Returns:
+ The content of the snapshot file.
+ """
+ snapshot = attempt_snapshot_load(snapshot_path, device, weights_only)
+ model.load_state_dict(snapshot["model"])
+ return snapshot
+
+
+def attempt_snapshot_load(
+ path: str | Path,
+ device: str,
+ weights_only: bool | None = None,
+) -> dict:
+ """Attempts to load a snapshot using `torch.load(...)`.
+
+ Args:
+ path: The path of the snapshot to try to load..
+ device: The device to use for the `map_location`.
+ weights_only: Value for torch.load() `weights_only` parameter.
+ If False, the python pickle module is used implicitly, which is known to be
+ insecure. Only set to False if you're loading data that you trust (e.g.
+ snapshots that you created yourself). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+
+ Returns:
+ The loaded snapshot.
+
+ Raises:
+ pickle.UnpicklingError: If `weights_only=True` but the snapshot failed to load
+ with `weights_only=True`.
+ """
+ try:
+ if weights_only is None:
+ weights_only = get_load_weights_only()
+
+ snapshot = torch.load(path, map_location=device, weights_only=weights_only)
+ except pickle.UnpicklingError as err:
+ logging.error(
+ f"\nFailed to load the snapshot: {path}.\n\n"
+ "If you trust the snapshot that you're trying to load, you can try\n"
+ "calling `Runner.load_snapshot` with `weights_only=False`. See the \n"
+ "error message below for more information and warnings.\n"
+ "You can set the `weights_only` parameter in the model configuration (\n"
+ "the content of the pytorch_config.yaml), as:\n\n```\n"
+ "runner:\n"
+ " load_weights_only: False\n```\n\n"
+ "If it's the detector snapshot that's failing to load, place the\n"
+ "`load_weights_only` key under the detector runner:\n\n```\n"
+ "detector:\n"
+ " runner:\n"
+ " load_weights_only: False\n```\n\n"
+ "You can also set the default `load_weights_only` that will be used when\n"
+ "the `load_weights_only` variable is not set in the `pytorch_config.yaml`\n"
+ "using `deeplabcut.pose_estimation_pytorch.set_load_weights_only(value)`:\n"
+ "\n```\n"
+ "from deeplabcut.pose_estimation_pytorch import set_load_weights_only\n"
+ "set_load_weights_only(True)\n"
+ "```\n\n"
+ "You can also set the value for `load_weights_only` with a \n"
+ "`TORCH_LOAD_WEIGHTS_ONLY` environment variable. If you call \n"
+ "`TORCH_LOAD_WEIGHTS_ONLY=False python -m deeplabcut`, it will launch the\n"
+ "DeepLabCut GUI with the default `load_weights_only` value to False.\n"
+ "If you set this value to `False`, make sure you only load snapshots that\n"
+ "you trust.\n\n"
+ )
+ raise err
+
+ return snapshot
+
+
+def fix_snapshot_metadata(path: str | Path) -> None:
+ """Replace numpy floats in snapshot metrics
+
+ Only call this method with snapshots that you trust, as torch.load(...) is called
+ with `weights_only=False`. For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+
+ DeepLabCut PyTorch snapshots trained with older releases may have `numpy` floats in
+ the stored metrics. This method opens the snapshots (with `weights_only=False`),
+ replaces the numpy floats with python floats (allowing to load with
+ `weights_only=True`), and saves the new snapshot data.
+
+ Warning: This overwrites your existing snapshot. If you want to ensure that no data
+ is lost, copy your snapshot before calling `fix_snapshot_metadata`.
+
+ Args:
+ path: The path of the snapshot to fix.
+ """
+ snapshot = torch.load(path, map_location="cpu", weights_only=False)
+ metrics = snapshot.get("metadata", {}).get("metrics")
+ if metrics is not None:
+ snapshot["metadata"]["metrics"] = {k: float(v) for k, v in metrics.items()}
+
+ torch.save(snapshot, path)
+
+
+def _add_numpy_to_torch_safe_globals():
+ """
+ Attempts tot add numpy classes allowing snapshots containing numpy floats in the
+ metrics to be loaded without needing to change the `weights_only` argument.
+
+ This fix only works for `numpy>=1.25.0`.
+ """
+ try:
+ from numpy.core.multiarray import scalar
+ from numpy.dtypes import Float64DType
+ torch.serialization.add_safe_globals([np.dtype, Float64DType, scalar])
+ except Exception:
+ pass
+
+
+_add_numpy_to_torch_safe_globals()
diff --git a/deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py b/deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py
new file mode 100644
index 0000000000..7c7387b279
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/dynamic_cropping.py
@@ -0,0 +1,188 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Modules to dynamically crop individuals out of videos to improve video analysis"""
+import math
+from dataclasses import dataclass, field
+from typing import Optional
+
+import torch
+
+
+@dataclass
+class DynamicCropper:
+ """
+ If the state is true, then dynamic cropping will be performed. That means that
+ if an object is detected (i.e. any body part > detection threshold), then object
+ boundaries are computed according to the smallest/largest x position and
+ smallest/largest y position of all body parts. This window is expanded by the
+ margin and from then on only the posture within this crop is analyzed (until the
+ object is lost, i.e. < detection threshold). The current position is utilized for
+ updating the crop window for the next frame (this is why the margin is important
+ and should be set large enough given the movement of the animal).
+
+ Attributes:
+ threshold: float
+ The threshold score for bodyparts above which an individual is deemed to
+ have been detected.
+ margin: int
+ The margin used to expand an individuals bounding box before cropping it.
+
+ Examples:
+ >>> import deeplabcut.pose_estimation_pytorch.models as models
+ >>>
+ >>> model: models.PoseModel
+ >>> frames: torch.Tensor # shape (num_frames, 3, H, W)
+ >>>
+ >>> dynamic = DynamicCropper(threshold=0.6, margin=25)
+ >>> predictions = []
+ >>> for image in frames:
+ >>> image = dynamic.crop(image)
+ >>>
+ >>> outputs = model(image)
+ >>> preds = model.get_predictions(outputs)
+ >>> pose = preds["bodypart"]["poses"]
+ >>>
+ >>> dynamic.update(pose)
+ >>> predictions.append(pose)
+ >>>
+ """
+ threshold: float
+ margin: int
+ _crop: tuple[int, int, int, int] | None = field(default=None, repr=False)
+ _shape: tuple[int, int] | None = field(default=None, repr=False)
+
+ def crop(self, image: torch.Tensor) -> torch.Tensor:
+ """Crops an input image according to the dynamic cropping parameters.
+
+ Args:
+ image: The image to crop, of shape (1, C, H, W).
+
+ Returns:
+ The cropped image of shape (1, C, H', W'), where [H', W'] is the size of
+ the crop.
+
+ Raises:
+ RuntimeError: if there is not exactly one image in the batch to crop, or if
+ `crop` was previously called with an image of a different width or
+ height.
+ """
+ if len(image) != 1:
+ raise RuntimeError(
+ "DynamicCropper can only be used with batch size 1 (found image "
+ f"shape: {image.shape})"
+ )
+
+ if self._shape is None:
+ self._shape = image.shape[3], image.shape[2]
+
+ if image.shape[3] != self._shape[0] or image.shape[2] != self._shape[1]:
+ raise RuntimeError(
+ "All frames must have the same shape; The first frame had (W, H) "
+ f"{self._shape} but the current frame has shape {image.shape}."
+ )
+
+ if self._crop is None:
+ return image
+
+ x0, y0, x1, y1 = self._crop
+ return image[:, :, y0:y1, x0:x1]
+
+ def update(self, pose: torch.Tensor) -> None:
+ """Updates the dynamic crop according to the pose model output.
+
+ Uses the pose predicted by the model to update the dynamic crop parameters for
+ the next frame. Scales the pose predicted in the cropped image back to the
+ original image space and returns it.
+
+ This method modifies the pose tensor in-place; so pass a copy of the tensor if
+ you need to keep the original values.
+
+ Args:
+ pose: The pose that was predicted by the pose estimation model in the
+ cropped image coordinate space.
+ """
+ if self._shape is None:
+ raise RuntimeError(f"You must call `crop` before calling `update`.")
+
+ # offset the pose to the original image space
+ offset_x, offset_y = 0, 0
+ if self._crop is not None:
+ offset_x, offset_y = self._crop[:2]
+ pose[..., 0] = pose[..., 0] + offset_x
+ pose[..., 1] = pose[..., 1] + offset_y
+
+ # check whether keypoints can be used for dynamic cropping
+ keypoints = pose[..., :3].reshape(-1, 3)
+ keypoints = keypoints[~torch.any(torch.isnan(keypoints), dim=1)]
+ if len(keypoints) == 0:
+ self.reset()
+ return
+
+ mask = keypoints[:, 2] >= self.threshold
+ if torch.all(~mask):
+ self.reset()
+ return
+
+ # set the crop coordinates
+ x0 = self._min_value(keypoints[:, 0], self._shape[0])
+ x1 = self._max_value(keypoints[:, 0], self._shape[0])
+ y0 = self._min_value(keypoints[:, 1], self._shape[1])
+ y1 = self._max_value(keypoints[:, 1], self._shape[1])
+ crop_w, crop_h = x1 - x0, y1 - y0
+ if crop_w == 0 or crop_h == 0:
+ self.reset()
+ return
+
+ self._crop = x0, y0, x1, y1
+
+ def reset(self) -> None:
+ """Resets the DynamicCropper to not crop the next frame"""
+ self._crop = None
+
+ @staticmethod
+ def build(
+ dynamic: bool, threshold: float, margin: int
+ ) -> Optional["DynamicCropper"]:
+ """Builds the DynamicCropper based on the given parameters
+
+ Args:
+ dynamic: Whether dynamic cropping should be used
+ threshold: The threshold score for bodyparts above which an individual is
+ deemed to have been detected.
+ margin: The margin used to expand an individuals bounding box before
+ cropping it.
+
+ Returns:
+ None if dynamic is False
+ DynamicCropper to use if dynamic is True
+ """
+ if not dynamic:
+ return None
+
+ return DynamicCropper(threshold, margin)
+
+ def _min_value(self, coordinates: torch.Tensor, maximum: int) -> int:
+ """Returns: min(coordinates - margin), clipped to [0, maximum]"""
+ return self._clip(
+ int(math.floor(torch.min(coordinates).item() - self.margin)),
+ maximum,
+ )
+
+ def _max_value(self, coordinates: torch.Tensor, maximum: int) -> int:
+ """Returns: max(coordinates + margin), clipped to [0, maximum]"""
+ return self._clip(
+ int(math.ceil(torch.max(coordinates).item() + self.margin)),
+ maximum,
+ )
+
+ def _clip(self, value: int, maximum: int) -> int:
+ """Returns: The value clipped to [0, maximum]"""
+ return min(max(value, 0), maximum)
diff --git a/deeplabcut/pose_estimation_pytorch/runners/inference.py b/deeplabcut/pose_estimation_pytorch/runners/inference.py
new file mode 100644
index 0000000000..c8027b2f3f
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/inference.py
@@ -0,0 +1,387 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Any, Generic, Iterable
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+import deeplabcut.pose_estimation_pytorch.runners.shelving as shelving
+from deeplabcut.pose_estimation_pytorch.data.postprocessor import Postprocessor
+from deeplabcut.pose_estimation_pytorch.data.preprocessor import Preprocessor
+from deeplabcut.pose_estimation_pytorch.models.detectors import BaseDetector
+from deeplabcut.pose_estimation_pytorch.models.model import PoseModel
+from deeplabcut.pose_estimation_pytorch.runners.base import ModelType, Runner
+from deeplabcut.pose_estimation_pytorch.runners.dynamic_cropping import DynamicCropper
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+class InferenceRunner(Runner, Generic[ModelType], metaclass=ABCMeta):
+ """Base class for inference runners
+
+ A runner takes a model and runs actions on it, such as training or inference
+ """
+
+ def __init__(
+ self,
+ model: ModelType,
+ batch_size: int = 1,
+ device: str = "cpu",
+ snapshot_path: str | Path | None = None,
+ preprocessor: Preprocessor | None = None,
+ postprocessor: Postprocessor | None = None,
+ load_weights_only: bool | None = None,
+ ):
+ """
+ Args:
+ model: The model to run actions on
+ device: The device to use (e.g. {'cpu', 'cuda:0', 'mps'})
+ snapshot_path: If defined, the path of a snapshot from which to load
+ pretrained weights
+ preprocessor: The preprocessor to use on images before inference
+ postprocessor: The postprocessor to use on images after inference
+ load_weights_only: Value for the torch.load() `weights_only` parameter.
+ If False, the python pickle module is used implicitly, which is known to
+ be insecure. Only set to False if you're loading data that you trust
+ (e.g. snapshots that you created). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+ """
+ super().__init__(model=model, device=device, snapshot_path=snapshot_path)
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError(f"batch_size must be a positive integer; is {batch_size}")
+
+ self.batch_size = batch_size
+ self.preprocessor = preprocessor
+ self.postprocessor = postprocessor
+
+ if self.snapshot_path is not None and self.snapshot_path != "":
+ self.load_snapshot(
+ self.snapshot_path,
+ self.device,
+ self.model,
+ weights_only=load_weights_only,
+ )
+
+ self._batch: torch.Tensor | None = None
+ self._contexts: list[dict] = []
+ self._image_batch_sizes: list[int] = []
+ self._predictions: list = []
+
+ @abstractmethod
+ def predict(self, inputs: torch.Tensor) -> list[dict[str, dict[str, np.ndarray]]]:
+ """Makes predictions from a model input and output
+
+ Args:
+ the inputs to the model, of shape (batch_size, ...)
+
+ Returns:
+ the predictions for each of the 'batch_size' inputs
+ """
+
+ @torch.no_grad()
+ def inference(
+ self,
+ images: (
+ Iterable[str | Path | np.ndarray]
+ | Iterable[tuple[str | Path | np.ndarray, dict[str, Any]]]
+ ),
+ shelf_writer: shelving.ShelfWriter | None = None,
+ ) -> list[dict[str, np.ndarray]]:
+ """Run model inference on the given dataset
+
+ TODO: Add an option to also return head outputs (such as heatmaps)? Can be
+ super useful for debugging
+
+ Args:
+ images: the images to run inference on, optionally with context
+ shelf_writer: by default, data are saved in a list and returned at the end
+ of inference. Passing a shelf manager writes data to disk on-the-fly
+ using a "shelf" (a pickle-based, persistent, database-like object by
+ default, resulting in constant memory footprint). The returned list is
+ then empty.
+
+ Returns:
+ a dict containing head predictions for each image
+ [
+ {
+ "bodypart": {"poses": np.array},
+ "unique_bodypart": {"poses": np.array},
+ }
+ ]
+ """
+ self.model.to(self.device)
+ self.model.eval()
+
+ results = []
+ for data in images:
+ self._prepare_inputs(data)
+ self._process_full_batches()
+ results += self._extract_results(shelf_writer)
+
+ # Process the last batch even if not full
+ if self._inputs_waiting_for_processing():
+ self._process_batch()
+ results += self._extract_results(shelf_writer)
+
+ return results
+
+ def _prepare_inputs(
+ self,
+ data: str | Path | np.ndarray | tuple[str | Path | np.ndarray, dict],
+ ) -> None:
+ """
+ Prepares inputs for an image and adds them to the data ready to be processed
+ """
+ if isinstance(data, (str, Path, np.ndarray)):
+ inputs, context = data, {}
+ else:
+ inputs, context = data
+
+ if self.preprocessor is not None:
+ inputs, context = self.preprocessor(inputs, context)
+ else:
+ inputs = torch.as_tensor(inputs)
+
+ self._contexts.append(context)
+ self._image_batch_sizes.append(len(inputs))
+
+ # skip when there are no inputs for an image
+ if len(inputs) == 0:
+ return
+
+ if self._batch is None:
+ self._batch = inputs
+ else:
+ self._batch = torch.cat([self._batch, inputs], dim=0)
+
+ def _process_full_batches(self) -> None:
+ """Processes prepared inputs in batches of the desired batch size."""
+ while self._batch is not None and len(self._batch) >= self.batch_size:
+ self._process_batch()
+
+ def _extract_results(self, shelf_writer: shelving.ShelfWriter) -> list:
+ """Obtains results that were obtained from processing a batch."""
+ results = []
+ while (
+ len(self._image_batch_sizes) > 0
+ and len(self._predictions) >= self._image_batch_sizes[0]
+ ):
+ num_predictions = self._image_batch_sizes[0]
+ image_predictions = self._predictions[:num_predictions]
+ context = self._contexts[0]
+
+ if self.postprocessor is not None:
+ # TODO: Should we return context?
+ # TODO: typing update - the post-processor can remove a dict level
+ image_predictions, _ = self.postprocessor(image_predictions, context)
+
+ if shelf_writer is not None:
+ shelf_writer.add_prediction(
+ bodyparts=image_predictions["bodyparts"],
+ unique_bodyparts=image_predictions.get("unique_bodyparts"),
+ identity_scores=image_predictions.get("identity_scores"),
+ features=image_predictions.get("features"),
+ )
+ else:
+ results.append(image_predictions)
+
+ self._contexts = self._contexts[1:]
+ self._image_batch_sizes = self._image_batch_sizes[1:]
+ self._predictions = self._predictions[num_predictions:]
+
+ return results
+
+ def _process_batch(self) -> None:
+ """
+ Processes a batch. There must be inputs waiting to be processed before this is
+ called, otherwise this method will raise an error.
+ """
+ batch = self._batch[: self.batch_size]
+ self._predictions += self.predict(batch)
+
+ # remove processed inputs from batch
+ if len(self._batch) <= self.batch_size:
+ self._batch = None
+ else:
+ self._batch = self._batch[self.batch_size :]
+
+ def _inputs_waiting_for_processing(self) -> bool:
+ """Returns: Whether there are inputs which have not yet been processed"""
+ return self._batch is not None and len(self._batch) > 0
+
+
+class PoseInferenceRunner(InferenceRunner[PoseModel]):
+ """Runner for pose estimation inference"""
+
+ def __init__(
+ self,
+ model: PoseModel,
+ dynamic: DynamicCropper | None = None,
+ **kwargs,
+ ):
+ super().__init__(model, **kwargs)
+ self.dynamic = dynamic
+ if dynamic is not None:
+ print(
+ f"Inference runner using dynamic cropping: {self.dynamic}.\n"
+ "Note that dynamic cropping should only be used to analyze videos with "
+ "bottom-up pose estimation models."
+ )
+ if self.batch_size != 1:
+ raise ValueError(
+ "Dynamic cropping can only be used with batch size 1. Please set "
+ "your batch size to 1."
+ )
+
+ def predict(self, inputs: torch.Tensor) -> list[dict[str, dict[str, np.ndarray]]]:
+ """Makes predictions from a model input and output
+
+ Args:
+ the inputs to the model, of shape (batch_size, ...)
+
+ Returns:
+ predictions for each of the 'batch_size' inputs, made by each head, e.g.
+ [
+ {
+ "bodypart": {"poses": np.ndarray},
+ "unique_bodypart": {"poses": np.ndarray},
+ }
+ ]
+ """
+ if self.dynamic is not None:
+ inputs = self.dynamic.crop(inputs)
+
+ outputs = self.model(inputs.to(self.device))
+ raw_predictions = self.model.get_predictions(outputs)
+
+ if self.dynamic is not None:
+ self.dynamic.update(raw_predictions["bodypart"]["poses"])
+
+ predictions = [
+ {
+ head: {
+ pred_name: pred[b].cpu().numpy()
+ for pred_name, pred in head_outputs.items()
+ }
+ for head, head_outputs in raw_predictions.items()
+ }
+ for b in range(len(inputs))
+ ]
+ return predictions
+
+
+class DetectorInferenceRunner(InferenceRunner[BaseDetector]):
+ """Runner for object detection inference"""
+
+ def __init__(self, model: BaseDetector, **kwargs):
+ """
+ Args:
+ model: The detector to use for inference.
+ **kwargs: Inference runner kwargs.
+ """
+ super().__init__(model, **kwargs)
+
+ def predict(self, inputs: torch.Tensor) -> list[dict[str, dict[str, np.ndarray]]]:
+ """Makes predictions from a model input and output
+
+ Args:
+ the inputs to the model, of shape (batch_size, ...)
+
+ Returns:
+ predictions for each of the 'batch_size' inputs, made by each head, e.g.
+ [
+ {
+ "bodypart": {"poses": np.ndarray},
+ "unique_bodypart": "poses": np.ndarray},
+ ]
+ """
+ _, raw_predictions = self.model(inputs.to(self.device))
+ predictions = [
+ {
+ "detection": {
+ "bboxes": item["boxes"].cpu().numpy().reshape(-1, 4),
+ "scores": item["scores"].cpu().numpy().reshape(-1),
+ }
+ }
+ for item in raw_predictions
+ ]
+ return predictions
+
+
+def build_inference_runner(
+ task: Task,
+ model: nn.Module,
+ device: str,
+ snapshot_path: str | Path,
+ batch_size: int = 1,
+ preprocessor: Preprocessor | None = None,
+ postprocessor: Postprocessor | None = None,
+ dynamic: DynamicCropper | None = None,
+ load_weights_only: bool | None = None,
+) -> InferenceRunner:
+ """
+ Build a runner object according to a pytorch configuration file
+
+ Args:
+ task: the inference task to run
+ model: the model to run
+ device: the device to use (e.g. {'cpu', 'cuda:0', 'mps'})
+ snapshot_path: the snapshot from which to load the weights
+ batch_size: the batch size to use to run inference
+ preprocessor: the preprocessor to use on images before inference
+ postprocessor: the postprocessor to use on images after inference
+ dynamic: The DynamicCropper used for video inference, or None if dynamic
+ cropping should not be used. Only for bottom-up pose estimation models.
+ Should only be used when creating inference runners for video pose
+ estimation with batch size 1.
+ load_weights_only: Value for the torch.load() `weights_only` parameter.
+ If False, the python pickle module is used implicitly, which is known to
+ be insecure. Only set to False if you're loading data that you trust (e.g.
+ snapshots that you created). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+
+ Returns:
+ The inference runner.
+ """
+ kwargs = dict(
+ model=model,
+ device=device,
+ snapshot_path=snapshot_path,
+ batch_size=batch_size,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ load_weights_only=load_weights_only,
+ )
+ if task == Task.DETECT:
+ if dynamic is not None:
+ raise ValueError(
+ f"The DynamicCropper can only be used for pose estimation; not object "
+ f"detection. Please turn off dynamic cropping."
+ )
+ return DetectorInferenceRunner(**kwargs)
+
+ if task != Task.BOTTOM_UP:
+ if dynamic is not None:
+ print(
+ "Turning off dynamic cropping. It should only be used for bottom-up "
+ f"pose estimation models, but you are using a {task} model."
+ )
+ dynamic = None
+
+ return PoseInferenceRunner(dynamic=dynamic, **kwargs)
diff --git a/deeplabcut/pose_estimation_pytorch/runners/logger.py b/deeplabcut/pose_estimation_pytorch/runners/logger.py
new file mode 100644
index 0000000000..deef219671
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/logger.py
@@ -0,0 +1,436 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import csv
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any, Optional
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+from torch.utils.data import DataLoader
+from torchvision.utils import draw_bounding_boxes, draw_keypoints
+
+try:
+ import wandb
+ has_wandb = True
+except ImportError:
+ has_wandb = False
+
+import deeplabcut.pose_estimation_pytorch.registry as deeplabcut_pose_estimation_pytorch_registry
+from deeplabcut.pose_estimation_pytorch.models.model import PoseModel
+
+LOGGER = deeplabcut_pose_estimation_pytorch_registry.Registry(
+ "loggers", build_func=deeplabcut_pose_estimation_pytorch_registry.build_from_cfg
+)
+
+
+def setup_file_logging(filepath: Path) -> None:
+ """
+ Sets up logging to a file
+
+ Args:
+ filepath: the path where logs should be saved
+ """
+ logging.basicConfig(
+ filename=filepath,
+ filemode="a",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO,
+ format="%(asctime)-15s %(message)s",
+ force=True,
+ )
+ console_logger = logging.StreamHandler()
+ console_logger.setLevel(logging.INFO)
+ root = logging.getLogger()
+ root.addHandler(console_logger)
+
+
+def destroy_file_logging() -> None:
+ """Resets the logging module to log everything to the console"""
+ root = logging.getLogger()
+ handlers = [h for h in root.handlers]
+ for handler in handlers:
+ root.removeHandler(handler)
+
+
+class BaseLogger(ABC):
+ """Base class for logging training runs"""
+
+ @abstractmethod
+ def log_config(self, config: dict = None) -> None:
+ """Logs the configuration data for a training run
+
+ Args:
+ config: the training configuration used for the run
+ """
+
+ @abstractmethod
+ def log(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
+ """Logs data from a training run
+
+ Args:
+ metrics: the metrics to log
+ step: The global step in processing. Defaults to None.
+ """
+
+ @abstractmethod
+ def save(self) -> None:
+ """Saves the current training logs"""
+
+
+class ImageLoggerMixin(ABC):
+ """Mixin for loggers that can log images
+
+ Before starting training, you should call `select_images_to_log`, which will
+ select a train and a test image for which inputs/outputs will always be logged.
+ Then logger.log_images should be called at every step - the logger will check if
+ anything needs to be uploaded, and take care of it.
+
+ Example:
+ project_name = "example"
+ run_name = "run-1"
+ logger = WandbLogger(project_name, run_name)
+ logger.select_images_to_log(train_loader, test_loader)
+
+ for i in range(epochs):
+ for batch_inputs in train_loader:
+ batch_labels = batch_data["annotations"]
+ batch_inputs = batch_data["image"]
+ batch_outputs = model(batch_inputs)
+ batch_targets = model.get_target(batch_outputs, batch_labels)
+ loss = criterion(batch_targets, batch_outputs)
+ loss.backwards()
+ optim.step()
+
+ logger.log_images(batch_inputs, batch_outputs, batch_targets)
+
+ for batch_inputs in train_loader:
+ ...
+ logger.log_images(batch_inputs, batch_outputs, batch_targets)
+ """
+
+ def __init__(self, image_log_interval: int | None = None, *args, **kwargs):
+ """"""
+ super().__init__(*args, **kwargs)
+ self.image_log_interval = image_log_interval
+ self._logged = {}
+ self._denormalize = transforms.Compose(
+ [
+ transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
+ transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
+ ]
+ )
+ self._softmax = torch.nn.Softmax2d()
+
+ @abstractmethod
+ def log_images(
+ self,
+ inputs: dict[str, Any],
+ outputs: dict[str, torch.Tensor],
+ targets: dict[str, dict[str, torch.Tensor]],
+ step: int,
+ ) -> None:
+ """Log images for a batch
+
+ Args:
+ inputs: the inputs for the model, containing at least an "image" key
+ outputs: the outputs of each model head
+ targets: the targets for each model head
+ step: the current step
+ """
+ pass
+
+ def select_images_to_log(self, train: DataLoader, valid: DataLoader) -> None:
+ """Selects the train and test images to log
+
+ Args:
+ train: the training dataloader
+ valid: the inference dataloader
+ """
+ def _caption(image_path: str) -> str:
+ p = Path(image_path)
+ return f"{p.parent.name}.{p.stem}"
+
+ train_image = train.dataset[0]["path"]
+ test_image = valid.dataset[0]["path"]
+ self._logged = {
+ train_image: {"name": "train-0", "caption": _caption(train_image)},
+ test_image: {"name": "test-0", "caption": _caption(test_image)},
+ }
+
+ def _prepare_image(
+ self,
+ image: torch.Tensor,
+ denormalize: bool = False,
+ keypoints: torch.Tensor | None = None,
+ bboxes: torch.Tensor | None = None,
+ ) -> np.ndarray:
+ """
+ Args:
+ image: the image to log, of shape (C, H, W), of any data type
+ denormalize: whether to remove ImageNet channel normalization
+ keypoints: size (num_instances, K, 2) the K keypoints location
+ bboxes: size (N, 4) containing bboxes in (xmin, ymin, xmax, ymax)
+
+ Returns:
+ an uint8 array with keypoints and bounding boxes drawn
+ """
+ if denormalize:
+ image = self._denormalize(image.unsqueeze(0)).squeeze()
+
+ image = F.convert_image_dtype(image.detach().cpu(), dtype=torch.uint8)
+ if keypoints is not None and len(keypoints) > 0:
+ assert len(keypoints.shape) == 3
+ # Use visibility and force torchvision >= 0.18
+ # pytorch.org/vision/0.18/generated/torchvision.utils.draw_keypoints.html
+ # pytorch.org/vision/0.17/generated/torchvision.utils.draw_keypoints.html
+ keypoints[torch.any(torch.isnan(keypoints), dim=-1)] = -1
+ image = draw_keypoints(
+ image, keypoints=keypoints[..., :2], colors="red", radius=5
+ )
+
+ if bboxes is not None and len(bboxes) > 0:
+ assert len(bboxes.shape) == 2
+ image = draw_bounding_boxes(image, boxes=bboxes[:, :4], width=1)
+
+ return image.permute(1, 2, 0).numpy()
+
+ def _heatmap_softmax(self, heatmaps: torch.Tensor) -> torch.Tensor:
+ """Applies a softmax to the heatmap channels"""
+ return self._softmax(heatmaps.detach().cpu())
+
+ def _prepare_images(
+ self,
+ inputs: dict[str, Any],
+ outputs: dict[str, dict[str, torch.Tensor]],
+ targets: dict[str, dict[str, dict[str, torch.Tensor]]],
+ ) -> dict[str, np.ndarray]:
+ """Prepares images for logging"""
+ image_logs = {}
+ paths = inputs["path"]
+ images_to_log = [(i, p) for i, p in enumerate(paths) if p in self._logged]
+ for idx, path in images_to_log:
+ base = self._logged[path]["name"]
+ keypoints = inputs.get("annotations", {}).get("keypoints")
+ if keypoints is not None:
+ keypoints = keypoints[idx]
+ image_logs[f"{base}.input"] = self._prepare_image(
+ inputs["image"][idx], keypoints=keypoints, denormalize=True,
+ )
+
+ for head, head_outputs in outputs.items():
+ if "heatmap" in head_outputs:
+ head_heatmaps = self._heatmap_softmax(head_outputs["heatmap"][idx])
+ head_targets = targets[head]["heatmap"]["target"][idx]
+ for j, (h, t) in enumerate(zip(head_heatmaps, head_targets)):
+ h = self._prepare_image(h.unsqueeze(0))
+ t = self._prepare_image(t.unsqueeze(0))
+ image_logs[f"{base}.heatmap.{j}"] = np.concatenate([h, t])
+
+ return image_logs
+
+
+@LOGGER.register_module
+class WandbLogger(ImageLoggerMixin, BaseLogger):
+ """Wandb logger to track experiments and log data.
+
+ Refer to: https://docs.wandb.ai/guides for more information on wandb.
+
+ Attributes:
+ run (wandb.Run): The wandb run object associated with the current experiment.
+ """
+
+ def __init__(
+ self,
+ project_name: str = "deeplabcut",
+ run_name: str = "tmp",
+ image_log_interval: int | None = None,
+ model: PoseModel = None,
+ **wandb_kwargs,
+ ) -> None:
+ """Initialize the WandbLogger class.
+
+ Args:
+ project_name: The name of the wandb project. Defaults to "deeplabcut".
+ run_name: The name of the wandb run. Defaults to "tmp".
+ image_log_interval: How often train/test images are logged in epochs (if
+ None, train/test inputs are never logged).
+ model: The model to log. Defaults to None.
+ wandb_kwargs: extra arguments to pass to ``wb.init``
+
+ Example:
+ logger = WandbLogger(project_name="mice", run_name="exp1", model=my_model)
+
+ """
+ super().__init__(image_log_interval=image_log_interval)
+
+ if not has_wandb:
+ raise ValueError(
+ "Cannot use ``WandbLogger`` as wandb is not installed. Please run"
+ "``pip install wandb`` if you want to log to wandb"
+ )
+
+ if wandb.run is not None:
+ wandb.finish()
+
+ self.run = wandb.init(
+ project=project_name,
+ name=run_name,
+ **wandb_kwargs,
+ )
+ if model is None:
+ raise ValueError("Specify the model to track!")
+ self.run.watch(model)
+
+ def log(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
+ """Logs metrics from runs
+
+ Args:
+ metrics: the metrics to log
+ step: The global step in processing. Defaults to None.
+
+ Example:
+ logger = WandbLogger()
+ logger.log({"loss": 0.123}, step=100)
+ """
+ self.run.log(metrics, step=step)
+
+ def log_images(
+ self,
+ inputs: dict[str, Any],
+ outputs: dict[str, dict[str, torch.Tensor]],
+ targets: dict[str, dict[str, dict[str, torch.Tensor]]],
+ step: int,
+ ) -> None:
+ """Log images for a batch
+
+ Args:
+ inputs: the inputs for the model, containing at least an "image" key
+ outputs: the outputs of each model head
+ targets: the targets for each model head
+ step: the current step
+ """
+ if self.image_log_interval is None or step % self.image_log_interval != 0:
+ return
+
+ images = self._prepare_images(inputs, outputs, targets)
+ if len(images) > 0:
+ self.run.log(
+ {name: wandb.Image(image) for name, image in images.items()},
+ step=step,
+ )
+
+ def save(self):
+ """Syncs all files to wandb with the policy specified.
+
+ Notes:
+ self.run: A run is a unit of computation logged by wandb.
+ self.run.run.dir: The directory where files associated with the run are saved.
+
+ Example:
+ logger = WandbLogger()
+ # Training and logging
+ logger.save()
+ """
+ self.run.save(self.run.dir)
+
+ def log_config(self, config: dict = None) -> None:
+ """Updates the current run with the given config dict.
+
+ Notes:
+ self.run: A run is a unit of computation logged by wandb.
+ self.run.config: Config object associated with this run.
+
+ Args:
+ config: Experiment config file.
+
+ Example:
+ logger = WandbLogger()
+ config = {"learning_rate": 0.001, "batch_size": 32}
+ logger.log_config(config)
+
+ """
+ self.run.config.update(config)
+
+
+@LOGGER.register_module
+class CSVLogger(BaseLogger):
+ """Logger saving stats and metrics to a CSV file"""
+
+ def __init__(self, train_folder: Path, log_filename: str) -> None:
+ """Initialize the WandbLogger class.
+
+ Args:
+ train_folder: The path of the folder containing training files.
+ log_filename: The name of the file in which to store training stats
+ """
+ super().__init__()
+ self.train_folder = train_folder
+ self.log_filename = log_filename
+ self.log_file = train_folder / log_filename
+
+ self._steps: list[int] = []
+ self._metric_store: list[dict] = []
+ self._logged_metrics: set[str] = set()
+
+ def log(self, metrics: dict[str, Any], step: Optional[int] = None) -> None:
+ """Logs metrics from runs
+
+ Args:
+ metrics: the metrics to log
+ step: The global step in processing. Defaults to None.
+ """
+ if step is None:
+ if len(self._steps) == 0:
+ step = 0
+ else:
+ step = self._steps[-1] + 1
+
+ self._logged_metrics = self._logged_metrics.union(metrics.keys())
+ if len(self._steps) > 0 and step == self._steps[-1]:
+ self._metric_store[-1].update(metrics)
+ else:
+ self._steps.append(step)
+ self._metric_store.append(metrics)
+
+ self.save()
+
+ def save(self):
+ """Saves the metrics to the file system"""
+ logs = self._prepare_logs()
+ with open(self.log_file, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerows(logs)
+
+ def log_config(self, config: dict = None) -> None:
+ """Does not do anything as the config should already be saved
+
+ Args:
+ config: Experiment config file.
+ """
+ pass
+
+ def _prepare_logs(self) -> list[list]:
+ """Prepares the data to log as a list of strings"""
+ if len(self._metric_store) == 0:
+ return []
+
+ metrics = list(sorted(self._logged_metrics))
+ logs = [["step"] + metrics]
+ for step, step_metrics in zip(self._steps, self._metric_store):
+ logs.append([step] + [step_metrics.get(m) for m in metrics])
+
+ return logs
diff --git a/deeplabcut/pose_estimation_pytorch/runners/schedulers.py b/deeplabcut/pose_estimation_pytorch/runners/schedulers.py
new file mode 100644
index 0000000000..3b4e2b2cda
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/schedulers.py
@@ -0,0 +1,130 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class LRListScheduler(_LRScheduler):
+ """
+ You can achieve increased performance and faster training by using a learning rate
+ that changes during training. A scheduler makes the learning rate adaptive. Given a
+ list of learning rates and milestones modifies the learning rate accordingly during
+ training.
+ """
+
+ def __init__(self, optimizer, milestones, lr_list, last_epoch=-1) -> None:
+ """
+ Args:
+ optimizer: optimizer used for learning.
+ milestones: number of epochs.
+ lr_list: learning rate list.
+ last_epoch: where to start the scheduler. (-1: start from beginning)
+
+ Examples:
+ input:
+ last_epoch = -1
+ verbose = False
+ milestones = [10, 30, 40]
+ lr_list = [[0.00001],[0.000005],[0.000001]]
+ """
+ self.milestones = milestones
+ self.lr_list = lr_list
+ super().__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ """Summary:
+ Given a milestones, get the corresponding learning rate.
+
+ Returns:
+ lr: learning rate value
+
+ Examples:
+ input: LRListScheduler object
+ output: learning rate (lr) = [0.001]
+ """
+ if self.last_epoch not in self.milestones:
+ return [group["lr"] for group in self.optimizer.param_groups]
+ return [lr for lr in self.lr_list[self.milestones.index(self.last_epoch)]]
+
+
+def build_scheduler(
+ scheduler_cfg: dict | None, optimizer: torch.optim.Optimizer
+) -> torch.optim.lr_scheduler.LRScheduler | None:
+ """Builds a scheduler from a configuration, if defined
+
+ Args:
+ scheduler_cfg: the configuration of the scheduler to build
+ optimizer: the optimizer the scheduler will be built for
+
+ Returns:
+ None if scheduler_cfg is None, otherwise the scheduler
+ """
+ if scheduler_cfg is None:
+ return None
+
+ if scheduler_cfg["type"] == "LRListScheduler":
+ scheduler = LRListScheduler
+ else:
+ scheduler = getattr(torch.optim.lr_scheduler, scheduler_cfg["type"])
+
+ parsed_params = {}
+ for param_name, param in scheduler_cfg["params"].items():
+ if isinstance(param, list):
+ param = [_parse_scheduler_param(p, optimizer) for p in param]
+ else:
+ param = _parse_scheduler_param(param, optimizer)
+
+ parsed_params[param_name] = param
+
+ return scheduler(optimizer=optimizer, **parsed_params)
+
+
+def _parse_scheduler_param(param: Any, optimizer: torch.optim.Optimizer) -> Any:
+ """Parses parameters so they're built as schedulers if they're configured as one"""
+ if isinstance(param, dict) and "type" in param:
+ param = build_scheduler(param, optimizer)
+
+ return param
+
+
+def load_scheduler_state(
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
+ state_dict: dict,
+) -> None:
+ """
+ Args:
+ scheduler: The scheduler for which to load the state dict.
+ state_dict: The state dict to load
+
+ Raises:
+ ValueError: if the state dict fails to load.
+ """
+ try:
+ scheduler.load_state_dict(state_dict)
+ except Exception as err:
+ raise ValueError(f"Failed to load state dict: {err}")
+
+ param_groups = scheduler.optimizer.param_groups
+ resume_lrs = scheduler.get_last_lr()
+
+ if len(param_groups) != len(resume_lrs):
+ raise ValueError(
+ f"Number of optimizer parameter groups ({len(param_groups)}) did not match "
+ f"number of learning rates to resume from ({len(scheduler.get_last_lr())})."
+ )
+
+ # Update the learning rate for the optimizer based on the scheduler
+ for group, resume_lr in zip(param_groups, resume_lrs):
+ group['lr'] = resume_lr
diff --git a/deeplabcut/pose_estimation_pytorch/runners/shelving.py b/deeplabcut/pose_estimation_pytorch/runners/shelving.py
new file mode 100644
index 0000000000..a23fd2ccba
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/shelving.py
@@ -0,0 +1,227 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Modules used to read/write shelve data during video analysis in DeepLabCut 3.0"""
+import pickle
+import shelve
+from abc import ABC
+from pathlib import Path
+
+import numpy as np
+
+
+class ShelfManager(ABC):
+ """Class to manage shelf data"""
+
+ def __init__(self, filepath: str | Path, flag: str = "r") -> None:
+ self.filepath = Path(filepath)
+ self.flag = flag
+
+ self._db: shelve.Shelf | None = None
+ self._open: bool = False
+
+ def open(self) -> None:
+ """Opens the shelf"""
+ self._db = shelve.open(
+ str(self.filepath),
+ flag=self.flag,
+ protocol=pickle.DEFAULT_PROTOCOL,
+ )
+ self._open = True
+
+ def close(self) -> None:
+ """Closes the shelf"""
+ if not self._open:
+ return
+
+ try:
+ self._db.close()
+ except AttributeError:
+ pass
+
+ self._open = False
+
+ def keys(self) -> list[str]:
+ if not self._open:
+ raise ValueError(f"You must call open() before reading keys!")
+
+ return [k for k in self._db]
+
+
+class ShelfReader(ShelfManager):
+ """Reads data from a shelf"""
+
+ def __getitem__(self, item: str) -> dict:
+ """Reads an item from the shelf.
+
+ Args:
+ item: The key of the item to read.
+
+ Returns:
+ The item.
+ """
+ if not self._open:
+ raise ValueError(f"You must call open() before reading data!")
+
+ return self._db[item]
+
+
+class ShelfWriter(ShelfManager):
+ """Writes data to a shelf on-the-fly during video analysis.
+
+ Args:
+ pose_cfg: The test pose config for the model.
+ filepath: The path where the data should be saved.
+ num_frames: The number of frames in the video. Used to set the number of
+ leading 0s in the keys of the dictionary. Default is 5 if the number of
+ frames is not given.
+
+ Attributes:
+ filepath: The path to the shelf.
+ """
+
+ def __init__(
+ self, pose_cfg: dict, filepath: str | Path, num_frames: int | None = None
+ ):
+ super().__init__(filepath, flag="c")
+ self._pose_cfg = pose_cfg
+ self._num_frames = num_frames
+ self._frame_index = 0
+
+ self._str_width = 5
+ if num_frames is not None:
+ self._str_width = int(np.ceil(np.log10(num_frames)))
+
+ def add_prediction(
+ self,
+ bodyparts: np.ndarray,
+ unique_bodyparts: np.ndarray | None = None,
+ identity_scores: np.ndarray | None = None,
+ **kwargs,
+ ) -> None:
+ """Adds the prediction for a frame to the shelf
+
+ Args:
+ bodyparts: The predicted bodyparts.
+ unique_bodyparts: The predicted unique bodyparts, if there are any.
+ identity_scores: The predicted identities, if there are any.
+ """
+ if not self._open:
+ raise ValueError(f"You must call open() before adding data!")
+
+ key = "frame" + str(self._frame_index).zfill(self._str_width)
+
+ # convert bodyparts to shape (num_bpts, num_assemblies, 3)
+ bodyparts = bodyparts.transpose((1, 0, 2))
+ coordinates = [bpt[:, :2] for bpt in bodyparts]
+ scores = [bpt[:, 2:3] for bpt in bodyparts]
+
+ # full pickle has bodyparts and unique bodyparts in same array
+ unique_bodyparts = kwargs.get("unique_bodyparts", None)
+ if unique_bodyparts is not None:
+ unique_bpts = unique_bodyparts.transpose((1, 0, 2))
+ coordinates += [bpt[:, :2] for bpt in unique_bpts]
+ scores += [bpt[:, 2:] for bpt in unique_bpts]
+
+ output = dict(coordinates=(coordinates,), confidence=scores, costs=None)
+
+ identity_scores = kwargs.get("identity_scores", None)
+ if identity_scores is not None:
+ # Reshape id scores from (num_assemblies, num_bpts, num_individuals)
+ # to the original DLC full pickle format: (num_bpts, num_assem, num_ind)
+ id_scores = identity_scores.transpose((1, 0, 2))
+ output["identity"] = [bpt_id_scores for bpt_id_scores in id_scores]
+
+ if unique_bodyparts is not None:
+ # needed for create_video_with_all_detections to display unique bpts
+ num_unique = unique_bodyparts.shape[1]
+ num_assem, num_ind = id_scores.shape[1:]
+ output["identity"] += [
+ -1 * np.ones((num_assem, num_ind)) for i in range(num_unique)
+ ]
+
+ self._db[key] = output
+ self._frame_index += 1
+
+ def close(self) -> None:
+ """Opens the shelf"""
+ if self._open and self._frame_index > 0:
+ self._db["metadata"]["nframes"] = self._frame_index
+
+ super().close()
+
+ def open(self) -> None:
+ """Opens the shelf"""
+ super().open()
+ self._frame_index = 0
+
+ all_joints = self._pose_cfg["all_joints"]
+ paf_graph = self._pose_cfg.get("partaffinityfield_graph", [])
+
+ self._db["metadata"] = {
+ "nms radius": self._pose_cfg.get("nmsradius"),
+ "minimal confidence": self._pose_cfg.get("minconfidence"),
+ "sigma": self._pose_cfg.get("sigma", 1),
+ "PAFgraph": paf_graph,
+ "PAFinds": self._pose_cfg.get("paf_best", np.arange(len(paf_graph))),
+ "all_joints": [[i] for i in range(len(all_joints))],
+ "all_joints_names": [
+ self._pose_cfg["all_joints_names"][i] for i in range(len(all_joints))
+ ],
+ "nframes": self._num_frames,
+ "key_str_width": self._str_width,
+ }
+
+
+class FeatureShelfWriter(ShelfWriter):
+ """Writes bodypart features to a shelf on-the-fly for ReID model training.
+
+ Args:
+ pose_cfg: The test pose config for the model.
+ filepath: The path where the data should be saved.
+ num_frames: The number of frames in the video. Used to set the number of
+ leading 0s in the keys of the dictionary. Default is 5 if the number of
+ frames is not given.
+
+ Attributes:
+ filepath: The path to the shelf.
+ """
+
+ def __init__(
+ self, pose_cfg: dict, filepath: str | Path, num_frames: int | None = None
+ ):
+ super().__init__(pose_cfg, filepath, num_frames)
+
+ def add_prediction(
+ self,
+ bodyparts: np.ndarray,
+ features: np.ndarray | None = None,
+ **kwargs,
+ ) -> None:
+ """Adds the prediction for a frame to the shelf
+
+ Args:
+ bodyparts: The predicted bodyparts.
+ features: The features for the bodyparts.
+ """
+ if not self._open:
+ raise ValueError(f"You must call open() before adding data!")
+
+ key = "frame" + str(self._frame_index).zfill(self._str_width)
+
+ # bodyparts to shape (num_assemblies, num_bpts, xy)
+ coordinates = bodyparts[:, :, :2]
+ if features is None:
+ raise ValueError(
+ "Backbone features must be given to the FeatureShelfWriter"
+ )
+
+ self._db[key] = dict(coordinates=coordinates, features=features)
+ self._frame_index += 1
diff --git a/deeplabcut/pose_estimation_pytorch/runners/snapshots.py b/deeplabcut/pose_estimation_pytorch/runners/snapshots.py
new file mode 100755
index 0000000000..74203f5be1
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/snapshots.py
@@ -0,0 +1,217 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Code to handle storing models"""
+from __future__ import annotations
+
+import re
+import warnings
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import numpy as np
+import torch
+
+
+@dataclass(frozen=True)
+class Snapshot:
+ best: bool
+ epochs: int | None
+ path: Path
+
+ def uid(self) -> str:
+ return self.path.stem.split("-")[-1]
+
+ @staticmethod
+ def from_path(path: Path) -> "Snapshot":
+ best = "-best" in path.stem
+ epochs = int(path.stem.split("-")[-1])
+ return Snapshot(best=best, epochs=epochs, path=path)
+
+
+@dataclass
+class TorchSnapshotManager:
+ """Class handling model checkpoint I/O
+
+ Attributes:
+ snapshot_prefix: The prefix to use when saving snapshots.
+ model_folder: The path to the directory where model snapshots should be stored.
+ key_metric: If defined, the metric is used to save the best model. Otherwise no
+ best model is used.
+ key_metric_asc: Whether the key metric is ascending (larger values are better).
+ max_snapshots: The maximum number of snapshots to store for the training run.
+ This does not include the best model (e.g., setting max_snapshots=5 will
+ mean that the 5 latest models will be kept, plus the best model)
+ save_epochs: The number of epochs between each model save
+ save_optimizer_state: Whether to store the optimizer state. This makes snapshots
+ much heavier, but allows to resume training as if it was never stopped.
+
+ Examples:
+ # Storing snapshots while training
+ model: nn.Module
+ loader = DLCLoader(...)
+ snapshot_manager = TorchSnapshotManager(
+ "snapshot",
+ loader.model_folder,
+ key_metric="test.mAP",
+ )
+ ...
+ for epoch in range(num_epochs):
+ train_epoch(model, data)
+ snapshot_manager.update({
+ "metadata": {
+ "metrics": {"mAP": ...}
+ },
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict()
+ })
+ """
+
+ snapshot_prefix: str
+ model_folder: Path
+ key_metric: str | None = None
+ key_metric_asc: bool = True
+ max_snapshots: int = 5
+ save_epochs: int = 25
+ save_optimizer_state: bool = False
+ _best_model_epochs: int = -1
+ _best_metric: float | None = None
+ _key: str = field(init=False)
+
+ def __post_init__(self):
+ assert self.max_snapshots > 0, f"max_snapshots must be a positive integer"
+ self._key = f"metrics/{self.key_metric}"
+
+ def update(self, epoch: int, state_dict: dict, last: bool = False) -> None:
+ """Saves the model state dict if the epoch is one that requires a save
+
+ Args:
+ epoch: the number of epochs the model was trained for
+ state_dict: the state dict to store
+ last: whether this is the last epoch in the training run, which forces a
+ model save no matter the epoch number
+
+ Returns:
+ the path to the saved snapshot if one
+ """
+ metrics = state_dict["metadata"]["metrics"]
+ if (
+ self._key in metrics
+ and not np.isnan(metrics[self._key])
+ and (
+ self._best_metric is None
+ or (self.key_metric_asc and self._best_metric < metrics[self._key])
+ or (not self.key_metric_asc and self._best_metric > metrics[self._key])
+ )
+ ):
+ current_best = self.best()
+ self._best_metric = metrics[self._key]
+
+ # Save the new best model
+ save_path = self.snapshot_path(epoch, best=True)
+ parsed_state_dict = {
+ k: v
+ for k, v in state_dict.items()
+ if self.save_optimizer_state or k != "optimizer"
+ }
+ torch.save(parsed_state_dict, save_path)
+
+ # Handle previous best model
+ if current_best is not None:
+ if current_best.epochs % self.save_epochs == 0:
+ new_name = self.snapshot_path(epoch=current_best.epochs)
+ current_best.path.rename(new_name)
+ else:
+ current_best.path.unlink(missing_ok=False)
+ elif last or epoch % self.save_epochs == 0:
+ # Save regular snapshot if needed
+ save_path = self.snapshot_path(epoch=epoch)
+ parsed_state_dict = {
+ k: v
+ for k, v in state_dict.items()
+ if self.save_optimizer_state or k != "optimizer"
+ }
+ torch.save(parsed_state_dict, save_path)
+
+ # Clean up old snapshots if needed
+ existing_snapshots = [s for s in self.snapshots() if not s.best]
+ if len(existing_snapshots) >= self.max_snapshots:
+ num_to_delete = len(existing_snapshots) - self.max_snapshots
+ to_delete = existing_snapshots[:num_to_delete]
+ for snapshot in to_delete:
+ snapshot.path.unlink(missing_ok=False)
+
+ def best(self) -> Snapshot | None:
+ """Returns: the path to the best snapshot, if it exists"""
+ snapshots = self.snapshots()
+ best_snapshots = [s for s in snapshots if s.best]
+ if len(best_snapshots) == 0:
+ return None
+
+ if len(best_snapshots) > 1:
+ warnings.warn(
+ f"TorchSnapshotManager.best(): found multiple best snapshots ("
+ f"{best_snapshots}), returning the last one."
+ )
+
+ best_snapshot = best_snapshots[-1]
+ return best_snapshot
+
+ def last(self) -> Snapshot | None:
+ """Returns: path to the last snapshot that was saved, if any snapshot exists"""
+ snapshots = self.snapshots(best_in_last=False)
+ if len(snapshots) == 0:
+ return None
+ return snapshots[-1]
+
+ def snapshots(self, best_in_last: bool = True) -> list[Snapshot]:
+ """
+ Args:
+ best_in_last: Whether to place the snapshot with the best performance in the
+ last position in the list, even if it wasn't the last epoch.
+
+ Returns:
+ The snapshots for a training run, sorted by the number of epochs they were
+ trained for. If ``best_in_last=True`` and a best snapshot exists, it will be
+ the last one in the list.
+ """
+
+ def _sort_key(snapshot: Snapshot) -> int:
+ return snapshot.epochs
+
+ def _sort_key_best_as_last(snapshot: Snapshot) -> tuple[int, int]:
+ return 1 if snapshot.best else 0, snapshot.epochs
+
+ pattern = r"^(" + self.snapshot_prefix + r"(-best)?-\d+\.pt)$"
+ snapshots = [
+ Snapshot.from_path(f)
+ for f in self.model_folder.iterdir()
+ if re.match(pattern, f.name)
+ ]
+
+ sort_key = _sort_key
+ if best_in_last:
+ sort_key = _sort_key_best_as_last
+ snapshots.sort(key=sort_key)
+ return snapshots
+
+ def snapshot_path(self, epoch: int, best: bool = False) -> Path:
+ """
+ Args:
+ epoch: the number of epochs for which a snapshot was trained
+ best: whether this is the best performing model for the training run
+
+ Returns:
+ the path where the model should be stored
+ """
+ uid = f"{epoch:03}"
+ if best:
+ uid = f"best-{uid}"
+ return self.model_folder / f"{self.snapshot_prefix}-{uid}.pt"
diff --git a/deeplabcut/pose_estimation_pytorch/runners/train.py b/deeplabcut/pose_estimation_pytorch/runners/train.py
new file mode 100644
index 0000000000..d9788c2922
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/runners/train.py
@@ -0,0 +1,765 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import logging
+from abc import ABCMeta, abstractmethod
+from collections import defaultdict
+from pathlib import Path
+from typing import Any, Generic
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DataParallel
+from torch.utils.data import DataLoader
+
+import deeplabcut.core.metrics as metrics
+import deeplabcut.pose_estimation_pytorch.runners.schedulers as schedulers
+from deeplabcut.pose_estimation_pytorch.models.detectors import BaseDetector
+from deeplabcut.pose_estimation_pytorch.models.model import PoseModel
+from deeplabcut.pose_estimation_pytorch.runners.base import (
+ attempt_snapshot_load,
+ ModelType,
+ Runner,
+)
+from deeplabcut.pose_estimation_pytorch.runners.logger import (
+ BaseLogger,
+ CSVLogger,
+ ImageLoggerMixin,
+)
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import TorchSnapshotManager
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+class TrainingRunner(Runner, Generic[ModelType], metaclass=ABCMeta):
+ """Base TrainingRunner class.
+
+ A TrainingRunner is used to fit models to datasets. Subclasses must implement the
+ ``step(self, batch, mode)`` method, which performs a single training or validation
+ step on a batch of data. The step is different depending on the model type (e.g.
+ a pose model step vs. an object detector step).
+
+ Args:
+ model: The model to fit.
+ optimizer: The optimizer to use to fit the model.
+ snapshot_manager: Manages how snapshots are saved to disk during training.
+ device: The device on which to run training (e.g. 'cpu', 'cuda', 'cuda:0').
+ gpus: Used to specify the GPU indices for multi-GPU training (e.g. [0, 1, 2, 3]
+ to train on 4 GPUs). When a GPUs list is given, the device must be 'cuda'.
+ eval_interval: The interval at which the model will be evaluated while training
+ (e.g. `eval_interva=5` means the model will be evaluated every 5 epochs).
+ snapshot_path: If continuing to train a model, the path to the snapshot to
+ resume training from.
+ scheduler: The learning rate scheduler (or it's configuration), if one should be
+ used.
+ load_scheduler_state_dict: When resuming training (snapshot_path is not None),
+ attempts to load the scheduler state dict from the snapshot. If you've
+ modified your scheduler, set this to False or the old scheduler parameters
+ might be used.
+ logger: Logger to monitor training (e.g. a WandBLogger).
+ log_filename: Name of the file in which to store training stats.
+ load_weights_only: Value for the torch.load() `weights_only` parameter if
+ `snapshot_path` is not None.
+ If False, the python pickle module is used implicitly, which is known to
+ be insecure. Only set to False if you're loading data that you trust
+ (e.g. snapshots that you created yourself). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+ """
+
+ def __init__(
+ self,
+ model: ModelType,
+ optimizer: dict | torch.optim.Optimizer,
+ snapshot_manager: TorchSnapshotManager,
+ device: str = "cpu",
+ gpus: list[int] | None = None,
+ eval_interval: int = 1,
+ snapshot_path: str | Path | None = None,
+ scheduler: dict | torch.optim.lr_scheduler.LRScheduler | None = None,
+ load_scheduler_state_dict: bool = True,
+ logger: BaseLogger | None = None,
+ log_filename: str = "learning_stats.csv",
+ load_weights_only: bool | None = None,
+ ):
+ super().__init__(
+ model=model, device=device, gpus=gpus, snapshot_path=snapshot_path
+ )
+ if isinstance(optimizer, dict):
+ optimizer = build_optimizer(model, optimizer)
+ if isinstance(scheduler, dict):
+ scheduler = schedulers.build_scheduler(scheduler, optimizer)
+
+ self.eval_interval = eval_interval
+ self.optimizer = optimizer
+ self.scheduler = scheduler
+ self.snapshot_manager = snapshot_manager
+ self.history: dict[str, list] = dict(train_loss=[], eval_loss=[])
+ self.csv_logger = CSVLogger(
+ train_folder=snapshot_manager.model_folder,
+ log_filename=log_filename,
+ )
+ self.logger = logger
+ self.starting_epoch = 0
+ self.current_epoch = 0
+
+ # some models cannot compute a validation loss (e.g. detectors)
+ self._print_valid_loss = True
+
+ if self.snapshot_path:
+ snapshot = self.load_snapshot(
+ self.snapshot_path,
+ self.device,
+ self.model,
+ weights_only=load_weights_only,
+ )
+ self.starting_epoch = snapshot.get("metadata", {}).get("epoch", 0)
+
+ if "optimizer" in snapshot:
+ self.optimizer.load_state_dict(snapshot["optimizer"])
+
+ self._load_scheduler_state_dict(load_scheduler_state_dict, snapshot)
+
+ self._metadata = dict(epoch=self.starting_epoch, metrics=dict(), losses=dict())
+ self._epoch_ground_truth = {}
+ self._epoch_predictions = {}
+
+ def state_dict(self) -> dict:
+ """Returns: the state dict for the runner"""
+ model = self.model
+ if self._data_parallel:
+ model = self.model.module
+
+ state_dict_ = dict(
+ metadata=self._metadata,
+ model=model.state_dict(),
+ optimizer=self.optimizer.state_dict(),
+ )
+ if self.scheduler is not None:
+ state_dict_["scheduler"] = self.scheduler.state_dict()
+
+ return state_dict_
+
+ @abstractmethod
+ def step(
+ self, batch: dict[str, Any], mode: str = "train"
+ ) -> dict[str, torch.Tensor]:
+ """Perform a single epoch gradient update or validation step
+
+ Args:
+ batch: the batch data on which to run a step
+ mode: "train" or "eval". Defaults to "train".
+
+ Raises:
+ ValueError: if mode is not in {"train", "eval"}
+
+ Returns:
+ A dictionary containing the different losses for the step
+ """
+
+ @abstractmethod
+ def _compute_epoch_metrics(self) -> dict[str, float]:
+ """Computes the metrics using the data accumulated during an epoch
+
+ Returns:
+ A dictionary containing the different losses for the step
+ """
+ raise NotImplementedError
+
+ def fit(
+ self,
+ train_loader: DataLoader,
+ valid_loader: DataLoader,
+ epochs: int,
+ display_iters: int,
+ ) -> None:
+ """Train model for the specified number of steps.
+
+ Args:
+ train_loader: Data loader, which is an iterator over train instances.
+ Each batch contains image tensor and heat maps tensor input samples.
+ valid_loader: Data loader used for validation of the model.
+ epochs: The number of training epochs.
+ display_iters: The number of iterations between each loss print
+
+ Example:
+ runner = Runner(model, optimizer, cfg, device='cuda')
+ runner.fit(train_loader, valid_loader, "example/models" epochs=50)
+ """
+ if self._data_parallel:
+ self.model = DataParallel(self.model, device_ids=self._gpus).cuda()
+ else:
+ self.model.to(self.device)
+
+ if isinstance(self.logger, ImageLoggerMixin):
+ self.logger.select_images_to_log(train_loader, valid_loader)
+
+ # continuing to train a model: either total epochs or extra epochs
+ if self.starting_epoch > 0:
+ epochs = self.starting_epoch + epochs
+
+ for e in range(self.starting_epoch + 1, epochs + 1):
+ self.current_epoch = e
+ self._metadata["epoch"] = e
+ train_loss = self._epoch(
+ train_loader, mode="train", display_iters=display_iters
+ )
+ if self.scheduler:
+ self.scheduler.step()
+
+ lr = self.optimizer.param_groups[0]["lr"]
+ msg = f"Epoch {e}/{epochs} (lr={lr}), train loss {float(train_loss):.5f}"
+ if e % self.eval_interval == 0:
+ with torch.no_grad():
+ logging.info(f"Training for epoch {e} done, starting evaluation")
+ valid_loss = self._epoch(
+ valid_loader, mode="eval", display_iters=display_iters
+ )
+ if self._print_valid_loss:
+ msg += f", valid loss {float(valid_loss):.5f}"
+
+ self.snapshot_manager.update(e, self.state_dict(), last=(e == epochs))
+ logging.info(msg)
+
+ epoch_metrics = self._metadata.get("metrics")
+ if (
+ e % self.eval_interval == 0
+ and epoch_metrics is not None
+ and len(epoch_metrics) > 0
+ ):
+ logging.info(f"Model performance:")
+ line_length = max([len(name) for name in epoch_metrics.keys()]) + 2
+ for name, score in epoch_metrics.items():
+ logging.info(f" {(name + ':').ljust(line_length)}{score:6.2f}")
+
+ def _epoch(
+ self,
+ loader: torch.utils.data.DataLoader,
+ mode: str = "train",
+ display_iters: int = 500,
+ ) -> float:
+ """Facilitates training over an epoch. Returns the loss over the batches.
+
+ Args:
+ loader: Data loader, which is an iterator over instances.
+ Each batch contains image tensor and heat maps tensor input samples.
+ mode: str identifier to instruct the Runner whether to train or evaluate.
+ Possible values are: "train" or "eval".
+ display_iters: the number of iterations between each loss print
+
+ Raises:
+ ValueError: When the given mode is invalid
+
+ Returns:
+ epoch_loss: Average of the loss over the batches.
+ """
+ if mode == "train":
+ self.model.train()
+ elif mode == "eval" or mode == "inference":
+ self.model.eval()
+ else:
+ raise ValueError(f"Runner mode must be train or eval, found mode={mode}.")
+
+ epoch_loss = []
+ loss_metrics = defaultdict(list)
+ for i, batch in enumerate(loader):
+ losses_dict = self.step(batch, mode)
+ if "total_loss" in losses_dict:
+ epoch_loss.append(losses_dict["total_loss"])
+ if (i + 1) % display_iters == 0 and mode != "eval":
+ logging.info(
+ f"Number of iterations: {i + 1}, "
+ f"loss: {losses_dict['total_loss']:.5f}, "
+ f"lr: {self.optimizer.param_groups[0]['lr']}"
+ )
+
+ for key in losses_dict.keys():
+ loss_metrics[key].append(losses_dict[key])
+
+ perf_metrics = None
+ if mode == "eval":
+ perf_metrics = self._compute_epoch_metrics()
+ self._metadata["metrics"] = perf_metrics
+ self._epoch_predictions = {}
+ self._epoch_ground_truth = {}
+
+ if len(epoch_loss) > 0:
+ epoch_loss = np.mean(epoch_loss).item()
+ else:
+ epoch_loss = 0
+ self.history[f"{mode}_loss"].append(epoch_loss)
+
+ metrics_to_log = {}
+ if perf_metrics:
+ for name, score in perf_metrics.items():
+ if not isinstance(score, (int, float)):
+ score = 0.0
+ metrics_to_log[name] = score
+
+ for key in loss_metrics:
+ name = f"{mode}.{key}"
+ val = float("nan")
+ if np.sum(~np.isnan(loss_metrics[key])) > 0:
+ val = np.nanmean(loss_metrics[key]).item()
+ self._metadata["losses"][name] = val
+ metrics_to_log[f"losses/{name}"] = val
+
+ self.csv_logger.log(metrics_to_log, step=self.current_epoch)
+ if self.logger:
+ self.logger.log(metrics_to_log, step=self.current_epoch)
+
+ return epoch_loss
+
+ def _load_scheduler_state_dict(self, load_state_dict: bool, snapshot: dict) -> None:
+ if self.scheduler is None:
+ return
+
+ loaded_state_dict = False
+ if load_state_dict and "scheduler" in snapshot:
+ try:
+ schedulers.load_scheduler_state(self.scheduler, snapshot["scheduler"])
+ loaded_state_dict = True
+ except ValueError as err:
+ logging.warning(
+ "Failed to load the scheduler state_dict. The scheduler will "
+ "restart at epoch 0. This is expected if the scheduler "
+ "configuration was edited since the original snapshot was "
+ f"trained. Error: {err}"
+ )
+
+ if not loaded_state_dict and self.starting_epoch > 0:
+ logging.info(
+ f"Setting the scheduler starting epoch to {self.starting_epoch}"
+ )
+ self.scheduler.last_epoch = self.starting_epoch
+
+
+class PoseTrainingRunner(TrainingRunner[PoseModel]):
+ """Runner to train pose estimation models"""
+
+ def __init__(
+ self,
+ model: PoseModel,
+ optimizer: torch.optim.Optimizer,
+ load_head_weights: bool = True,
+ **kwargs,
+ ):
+ """
+ Args:
+ model: The neural network for solving pose estimation task.
+ optimizer: A PyTorch optimizer for updating model parameters.
+ load_head_weights: When `snapshot_path` is not None, whether to load the
+ head weights from the saved snapshot or just the backbone weights.
+ **kwargs: TrainingRunner kwargs
+ """
+ self._load_head_weights = load_head_weights
+ super().__init__(model, optimizer, **kwargs)
+
+ def load_snapshot(
+ self,
+ snapshot_path: str | Path,
+ device: str,
+ model: PoseModel,
+ weights_only: bool | None = None,
+ ) -> dict:
+ """Loads the state dict for a model from a file
+
+ This method loads a file containing a DeepLabCut PyTorch model snapshot onto
+ a given device, and sets the model weights using the state_dict.
+
+ Args:
+ snapshot_path: the path containing the model weights to load
+ device: the device on which the model should be loaded
+ model: the model for which the weights are loaded
+ weights_only: Value for torch.load() `weights_only` parameter.
+ If False, the python pickle module is used implicitly, which is known to
+ be insecure. Only set to False if you're loading data that you trust
+ (e.g. snapshots that you created yourself). For more information, see:
+ https://pytorch.org/docs/stable/generated/torch.load.html
+ If None, the default value is used:
+ `deeplabcut.pose_estimation_pytorch.get_load_weights_only()`
+
+ Returns:
+ The content of the snapshot file.
+ """
+ snapshot = attempt_snapshot_load(snapshot_path, device, weights_only)
+ if self._load_head_weights:
+ model.load_state_dict(snapshot["model"])
+ else:
+ backbone_prefix = "backbone."
+ backbone_weights = {
+ k[len(backbone_prefix) :]: v
+ for k, v in snapshot["model"].items()
+ if k.startswith(backbone_prefix)
+ }
+ model.backbone.load_state_dict(backbone_weights)
+
+ return snapshot
+
+ def step(
+ self, batch: dict[str, Any], mode: str = "train"
+ ) -> dict[str, torch.Tensor]:
+ """Perform a single epoch gradient update or validation step.
+
+ Args:
+ batch: Tuple of input image(s) and target(s) for train or valid single step.
+ mode: `train` or `eval`. Defaults to "train".
+
+ Raises:
+ ValueError: "Runner must be in train or eval mode, but {mode} was found."
+
+ Returns:
+ dict: {
+ "total_loss": aggregate_loss,
+ "aux_loss_1": loss_value,
+ ...,
+ }
+ """
+ if mode not in ["train", "eval"]:
+ raise ValueError(
+ f"BottomUpSolver must be in train or eval mode, but {mode} was found."
+ )
+
+ if mode == "train":
+ self.optimizer.zero_grad()
+
+ inputs = batch["image"]
+ inputs = inputs.to(self.device).float()
+ outputs = self.model(inputs)
+
+ if self._data_parallel:
+ underlying_model = self.model.module
+ else:
+ underlying_model = self.model
+
+ target = underlying_model.get_target(outputs, batch["annotations"])
+ losses_dict = underlying_model.get_loss(outputs, target)
+ if mode == "train":
+ losses_dict["total_loss"].backward()
+ self.optimizer.step()
+
+ if isinstance(self.logger, ImageLoggerMixin):
+ self.logger.log_images(batch, outputs, target, step=self.current_epoch)
+
+ if mode == "eval":
+ predictions = {
+ name: {k: v.detach().cpu().numpy() for k, v in pred.items()}
+ for name, pred in underlying_model.get_predictions(outputs).items()
+ }
+
+ ground_truth = batch["annotations"]["keypoints"]
+ if batch["annotations"]["with_center_keypoints"][0]:
+ ground_truth = ground_truth[..., :-1, :]
+
+ self._update_epoch_predictions(
+ name="bodyparts",
+ gt_keypoints=ground_truth,
+ pred_keypoints=predictions["bodypart"]["poses"],
+ offsets=batch["offsets"],
+ scales=batch["scales"],
+ )
+ if "unique_bodypart" in predictions:
+ self._update_epoch_predictions(
+ name="unique_bodyparts",
+ gt_keypoints=batch["annotations"]["keypoints_unique"],
+ pred_keypoints=predictions["unique_bodypart"]["poses"],
+ offsets=batch["offsets"],
+ scales=batch["scales"],
+ )
+
+ return {k: v.detach().cpu().numpy() for k, v in losses_dict.items()}
+
+ def _compute_epoch_metrics(self) -> dict[str, float]:
+ """Computes the metrics using the data accumulated during an epoch
+ Returns:
+ A dictionary containing the different losses for the step
+ """
+ scores = metrics.compute_metrics(
+ ground_truth=self._epoch_ground_truth["bodyparts"],
+ predictions=self._epoch_predictions["bodyparts"],
+ single_animal=False,
+ unique_bodypart_gt=self._epoch_ground_truth.get("unique_bodyparts"),
+ unique_bodypart_poses=self._epoch_predictions.get("unique_bodyparts"),
+ pcutoff=0.6,
+ compute_detection_rmse=False,
+ )
+ return {f"metrics/test.{metric}": value for metric, value in scores.items()}
+
+ def _update_epoch_predictions(
+ self,
+ name: str,
+ gt_keypoints: torch.Tensor,
+ pred_keypoints: torch.Tensor,
+ scales: torch.Tensor,
+ offsets: torch.Tensor,
+ ) -> None:
+ """Updates the stored predictions with a new batch"""
+ epoch_gt_metric = self._epoch_ground_truth.get(name, {})
+ epoch_metric = self._epoch_predictions.get(name, {})
+ assert len(gt_keypoints) == len(pred_keypoints)
+ assert len(offsets) == len(scales)
+ scales = scales.detach().cpu().numpy()
+ offsets = offsets.detach().cpu().numpy()
+
+ for gt, pred, scale, offset in zip(
+ gt_keypoints,
+ pred_keypoints,
+ scales,
+ offsets,
+ ):
+ ground_truth = gt.detach().cpu().numpy()
+ pred = pred.copy()
+
+ # rescale to the full image for TD or CTD
+ ground_truth[..., :2] = (ground_truth[..., :2] * scale) + offset
+ pred[..., :2] = (pred[..., :2] * scale) + offset
+
+ # we don't care about image paths here - use a default index
+ index = len(epoch_metric) + 1
+ epoch_gt_metric[f"sample{index:09}"] = ground_truth
+ epoch_metric[f"sample{index:09}"] = pred
+
+ self._epoch_ground_truth[name] = epoch_gt_metric
+ self._epoch_predictions[name] = epoch_metric
+
+
+class DetectorTrainingRunner(TrainingRunner[BaseDetector]):
+ """Runner to train object detection models"""
+
+ def __init__(self, model: BaseDetector, optimizer: torch.optim.Optimizer, **kwargs):
+ """
+ Args:
+ model: The detector model to train.
+ optimizer: The optimizer to use to train the model.
+ **kwargs: TrainingRunner kwargs
+ """
+ log_filename = "learning_stats_detector.csv"
+ if "log_filename" in kwargs:
+ log_filename = kwargs.pop("log_filename")
+
+ super().__init__(model, optimizer, log_filename=log_filename, **kwargs)
+ self._pycoco_warning_displayed = False
+ self._print_valid_loss = False
+
+ def step(
+ self, batch: dict[str, Any], mode: str = "train"
+ ) -> dict[str, torch.Tensor]:
+ """Perform a single epoch gradient update or validation step.
+
+ Args:
+ batch: Tuple of input image(s) and target(s) for train or valid single step.
+ mode: `train` or `eval`. Defaults to "train".
+
+ Raises:
+ ValueError: "Runner must be in train or eval mode, but {mode} was found."
+
+ Returns:
+ dict: {
+ 'total_loss': torch.Tensor,
+ 'aux_loss_1': torch.Tensor,
+ ...,
+ }
+ """
+ if mode not in ["train", "eval"]:
+ raise ValueError(
+ f"DetectorSolver must be in train or eval mode, but {mode} was found."
+ )
+
+ if mode == "train":
+ self.optimizer.zero_grad()
+ self.model.train()
+ else:
+ self.model.eval()
+
+ images = batch["image"]
+ images = images.to(self.device)
+
+ if self._data_parallel:
+ underlying_model = self.model.module
+ else:
+ underlying_model = self.model
+
+ target = underlying_model.get_target(batch["annotations"])
+ for item in target: # target is a list here
+ for key in item:
+ if item[key] is not None:
+ item[key] = item[key].to(self.device)
+
+ losses, predictions = self.model(images, target)
+
+ # losses only returned during training, not evaluation
+ if mode == "train":
+ losses["total_loss"] = sum(loss_part for loss_part in losses.values())
+ losses["total_loss"].backward()
+ self.optimizer.step()
+ losses = {k: v.detach().cpu().numpy() for k, v in losses.items()}
+
+ elif mode == "eval":
+ losses["total_loss"] = float("nan")
+ self._update_epoch_predictions(
+ paths=batch["path"],
+ sizes=batch["original_size"],
+ bboxes=batch["annotations"]["boxes"],
+ predictions=predictions,
+ offsets=batch["offsets"],
+ scales=batch["scales"],
+ )
+
+ return losses
+
+ def _compute_epoch_metrics(self) -> dict[str, float]:
+ """Returns: bounding box metrics, if"""
+ try:
+ return {
+ f"metrics/test.{k}": v
+ for k, v in metrics.compute_bbox_metrics(
+ self._epoch_ground_truth, self._epoch_predictions
+ ).items()
+ }
+ except ModuleNotFoundError:
+ if not self._pycoco_warning_displayed:
+ logging.info(
+ "\nNote:\n"
+ "Cannot compute bounding box metrics as ``pycocotools`` is not "
+ "installed. If you want bounding box mAP metrics when training "
+ "detectors for top-down models, please run ``pip install "
+ "pycocotools``.\n"
+ )
+ self._pycoco_warning_displayed = True
+
+ return {}
+
+ def _update_epoch_predictions(
+ self,
+ paths: torch.Tensor,
+ sizes: torch.Tensor,
+ bboxes: torch.Tensor,
+ predictions: list[dict[str, torch.Tensor]],
+ scales: torch.Tensor,
+ offsets: torch.Tensor,
+ ) -> None:
+ """Updates the stored predictions with a new batch"""
+ for img_path, img_size, img_bboxes, img_pred, scale, offset in zip(
+ paths, sizes, bboxes, predictions, scales, offsets
+ ):
+ scale_x, scale_y = scale
+ scale_factors = np.array([scale_x, scale_y, scale_x, scale_y])
+ offset = np.array(offset)
+
+ # remove bboxes that are not visible
+ img_bbox_mask = (img_bboxes[:, 2] > 0.0) & (img_bboxes[:, 3] > 0.0)
+ img_bboxes = img_bboxes[img_bbox_mask]
+
+ # rescale ground truth bounding boxes
+ gt_rescaled = img_bboxes.cpu().numpy() * scale_factors
+ gt_rescaled[..., :2] = gt_rescaled[..., :2] + offset
+
+ # convert to COCO format (xywh) before rescaling
+ pred_rescaled = img_pred["boxes"].detach().cpu().numpy()
+ pred_rescaled[:, 2] -= pred_rescaled[:, 0]
+ pred_rescaled[:, 3] -= pred_rescaled[:, 1]
+ pred_rescaled[..., :4] = pred_rescaled[..., :4] * scale_factors
+ pred_rescaled[..., :2] = pred_rescaled[..., :2] + offset
+
+ self._epoch_ground_truth[img_path] = {
+ "bboxes": gt_rescaled,
+ "width": img_size[1],
+ "height": img_size[0],
+ }
+ self._epoch_predictions[img_path] = {
+ "bboxes": pred_rescaled,
+ "scores": img_pred["scores"].detach().cpu().numpy(),
+ }
+
+
+def build_training_runner(
+ runner_config: dict,
+ model_folder: Path,
+ task: Task,
+ model: nn.Module,
+ device: str,
+ gpus: list[int] | None = None,
+ snapshot_path: str | Path | None = None,
+ load_head_weights: bool = True,
+ logger: BaseLogger | None = None,
+) -> TrainingRunner:
+ """
+ Build a runner object according to a pytorch configuration file
+
+ Args:
+ runner_config: the configuration for the runner
+ model_folder: the folder where models should be saved
+ task: the task the runner will perform
+ model: the model to run
+ device: the device to use (e.g. {'cpu', 'cuda:0', 'mps'})
+ gpus: the list of GPU indices to use for multi-GPU training
+ snapshot_path: the snapshot from which to load the weights
+ load_head_weights: When `snapshot_path` is not None and a pose model is being
+ trained, whether to load the head weights from the saved snapshot.
+ logger: the logger to use, if any
+
+ Returns:
+ the runner that was built
+ """
+ optimizer = build_optimizer(model, runner_config["optimizer"])
+ scheduler = schedulers.build_scheduler(runner_config.get("scheduler"), optimizer)
+
+ # if no custom snapshot prefix is defined, use the default one
+ snapshot_prefix = runner_config.get("snapshot_prefix")
+ if snapshot_prefix is None or len(snapshot_prefix) == 0:
+ snapshot_prefix = task.snapshot_prefix
+
+ kwargs = dict(
+ model=model,
+ optimizer=optimizer,
+ snapshot_manager=TorchSnapshotManager(
+ snapshot_prefix=snapshot_prefix,
+ model_folder=model_folder,
+ key_metric=runner_config.get("key_metric"),
+ key_metric_asc=runner_config.get("key_metric_asc"),
+ max_snapshots=runner_config["snapshots"]["max_snapshots"],
+ save_epochs=runner_config["snapshots"]["save_epochs"],
+ save_optimizer_state=runner_config["snapshots"]["save_optimizer_state"],
+ ),
+ device=device,
+ gpus=gpus,
+ eval_interval=runner_config.get("eval_interval"),
+ snapshot_path=snapshot_path,
+ scheduler=scheduler,
+ load_scheduler_state_dict=runner_config.get("load_scheduler_state_dict", True),
+ logger=logger,
+ load_weights_only=runner_config.get("load_weights_only", None),
+ )
+ if task == Task.DETECT:
+ return DetectorTrainingRunner(**kwargs)
+
+ kwargs["load_head_weights"] = load_head_weights
+ return PoseTrainingRunner(**kwargs)
+
+
+def build_optimizer(
+ model: nn.Module,
+ optimizer_config: dict,
+) -> torch.optim.Optimizer:
+ """Builds an optimizer from a configuration.
+
+ Args:
+ model: The model to optimize.
+ optimizer_config: The configuration for the optimizer.
+
+ Returns:
+ The optimizer for the model built according to the given configuration.
+ """
+ optim_cls = getattr(torch.optim, optimizer_config["type"])
+ optimizer = optim_cls(params=model.parameters(), **optimizer_config["params"])
+ return optimizer
diff --git a/deeplabcut/pose_estimation_pytorch/task.py b/deeplabcut/pose_estimation_pytorch/task.py
new file mode 100644
index 0000000000..944b9bc441
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/task.py
@@ -0,0 +1,41 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Types of tasks that can be run by DeepLabCut pose estimation models"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+from enum import Enum
+
+
+@dataclass(frozen=True)
+class TaskDataMixin:
+ aliases: tuple[str]
+ snapshot_prefix: str
+
+
+class Task(TaskDataMixin, Enum):
+ """A task to solve"""
+
+ BOTTOM_UP = ("BU", "BottomUp"), "snapshot"
+ DETECT = ("DT", "Detect"), "snapshot-detector"
+ TOP_DOWN = ("TD", "TopDown"), "snapshot"
+
+ @classmethod
+ def _missing_(cls, value):
+ if isinstance(value, str):
+ value = value.upper()
+ for member in cls:
+ if value in member.aliases:
+ return member
+ return None
+
+ def __repr__(self) -> str:
+ return f"Task.{self.name}"
diff --git a/deeplabcut/pose_estimation_pytorch/utils.py b/deeplabcut/pose_estimation_pytorch/utils.py
new file mode 100644
index 0000000000..7ccacaa904
--- /dev/null
+++ b/deeplabcut/pose_estimation_pytorch/utils.py
@@ -0,0 +1,71 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import os
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from deeplabcut.utils.auxiliaryfunctions import read_plainconfig
+
+
+def create_folder(path_to_folder):
+ """Creates all folders contained in the path.
+
+ Args:
+ path_to_folder: Path to the folder that should be created
+ """
+ if not os.path.exists(path_to_folder):
+ os.makedirs(path_to_folder)
+
+
+def fix_seeds(seed: int) -> None:
+ """
+ Fixes the random seed for python, numpy and pytorch
+
+ Args:
+ seed: the seed to set
+ """
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def resolve_device(model_config: dict) -> str:
+ """Determines which device should be used from the model config
+
+ When the device is set to 'auto':
+ If an Nvidia GPU is available, selects the device as cuda:0.
+ Selects 'mps' if available (on macOS) and the net type is compatible.
+ Otherwise, returns 'cpu'.
+ Otherwise, simply returns the selected device
+
+ Args:
+ model_config: the configuration for the pose model
+
+ Returns:
+ the device on which training should be run
+ """
+ device = model_config["device"]
+ supports_mps = "resnet" in model_config.get("net_type", "resnet")
+
+ if device == "auto":
+ if torch.cuda.is_available():
+ return "cuda"
+ elif supports_mps and torch.backends.mps.is_available():
+ return "mps"
+ return "cpu"
+ return device
diff --git a/deeplabcut/pose_estimation_tensorflow/__init__.py b/deeplabcut/pose_estimation_tensorflow/__init__.py
index 38586cb91f..45560a1114 100644
--- a/deeplabcut/pose_estimation_tensorflow/__init__.py
+++ b/deeplabcut/pose_estimation_tensorflow/__init__.py
@@ -12,6 +12,9 @@
# Licensed under GNU Lesser General Public License v3.0
#
+# Suppress tensorflow warning messages
+import tensorflow as tf
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from deeplabcut.pose_estimation_tensorflow.config import *
from deeplabcut.pose_estimation_tensorflow.datasets import *
@@ -26,6 +29,3 @@
from deeplabcut.pose_estimation_tensorflow.training import *
from deeplabcut.pose_estimation_tensorflow.util import *
from deeplabcut.pose_estimation_tensorflow.visualizemaps import *
-from deeplabcut.pose_estimation_tensorflow.predict_supermodel import (
- video_inference_superanimal,
-)
diff --git a/deeplabcut/pose_estimation_tensorflow/core/evaluate.py b/deeplabcut/pose_estimation_tensorflow/core/evaluate.py
index 33842b613c..3bccdf4a1b 100644
--- a/deeplabcut/pose_estimation_tensorflow/core/evaluate.py
+++ b/deeplabcut/pose_estimation_tensorflow/core/evaluate.py
@@ -14,6 +14,7 @@
import os
from pathlib import Path
from typing import List, Union
+
import numpy as np
import pandas as pd
from tqdm import tqdm
@@ -817,7 +818,7 @@ def evaluate_network(
cfg,
shuffle,
trainFraction,
- training_iterations,
+ trainingsiterations=training_iterations,
modelprefix=modelprefix,
)
print(
diff --git a/deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py b/deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py
index c4120dd8d6..9c34356678 100644
--- a/deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py
+++ b/deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py
@@ -15,10 +15,11 @@
from pathlib import Path
import numpy as np
import pandas as pd
-from scipy.spatial import cKDTree
from tqdm import tqdm
from typing import List
+from deeplabcut.core import crossvalutils
+from deeplabcut.core.crossvalutils import find_closest_neighbors
from deeplabcut.pose_estimation_tensorflow.core.evaluate import (
make_results_file,
keypoint_error,
@@ -27,7 +28,6 @@
)
from deeplabcut.pose_estimation_tensorflow.training import return_train_network_path
from deeplabcut.pose_estimation_tensorflow.config import load_config
-from deeplabcut.pose_estimation_tensorflow.lib import crossvalutils
from deeplabcut.utils import visualization
@@ -53,24 +53,6 @@ def _compute_stats(df):
).stack(level=1)
-def _find_closest_neighbors(xy_true, xy_pred, k=5):
- n_preds = xy_pred.shape[0]
- tree = cKDTree(xy_pred)
- dist, inds = tree.query(xy_true, k=k)
- idx = np.argsort(dist[:, 0])
- neighbors = np.full(len(xy_true), -1, dtype=int)
- picked = set()
- for i, ind in enumerate(inds[idx]):
- for j in ind:
- if j not in picked:
- picked.add(j)
- neighbors[idx[i]] = j
- break
- if len(picked) == n_preds:
- break
- return neighbors
-
-
def _calc_prediction_error(data):
_ = data.pop("metadata", None)
dists = []
@@ -78,7 +60,7 @@ def _calc_prediction_error(data):
gt = np.concatenate(dict_["groundtruth"][1])
xy = np.concatenate(dict_["prediction"]["coordinates"][0])
p = np.concatenate(dict_["prediction"]["confidence"])
- neighbors = _find_closest_neighbors(gt, xy)
+ neighbors = find_closest_neighbors(gt, xy)
found = neighbors != -1
gt2 = gt[found]
xy2 = xy[neighbors[found]]
@@ -280,7 +262,7 @@ def evaluate_multianimal_full(
cfg,
shuffle,
trainFraction,
- training_iterations,
+ trainingsiterations=training_iterations,
modelprefix=modelprefix,
)
print(
@@ -403,7 +385,7 @@ def evaluate_multianimal_full(
# Pick the predictions closest to ground truth,
# rather than the ones the model has most confident in
xy_gt_values = xy_gt.iloc[inds_gt].values
- neighbors = _find_closest_neighbors(
+ neighbors = find_closest_neighbors(
xy_gt_values, xy, k=3
)
found = neighbors != -1
@@ -670,3 +652,7 @@ def evaluate_multianimal_full(
make_results_file(final_result, evaluationfolder, DLCscorer)
os.chdir(str(start_path))
+
+
+# backwards compatibility
+_find_closest_neighbors = find_closest_neighbors
diff --git a/deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py b/deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py
index 79f96c539a..4d98c226a0 100644
--- a/deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py
+++ b/deeplabcut/pose_estimation_tensorflow/core/train_multianimal.py
@@ -61,7 +61,10 @@ def train(
setup_logging()
- cfg = load_config(config_yaml)
+ if isinstance(config_yaml, dict):
+ cfg = config_yaml
+ else:
+ cfg = load_config(config_yaml)
cfg["pseudo_threshold"] = pseudo_threshold
cfg["video_path"] = video_path
@@ -192,8 +195,10 @@ def train(
cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
lr_gen = LearningRate(cfg)
- stats_path = Path(config_yaml).with_name("learning_stats.csv")
- lrf = open(str(stats_path), "w")
+ lrf = None
+ if not isinstance(config_yaml, dict):
+ stats_path = Path(config_yaml).with_name("learning_stats.csv")
+ lrf = open(str(stats_path), "w")
print("Training parameters:")
print(cfg)
@@ -232,26 +237,28 @@ def train(
)
)
- lrf.write(
- "iteration: {}, loss: {}, scmap loss: {}, locref loss: {}, limb loss: {}, lr: {}\n".format(
- it,
- "{0:.4f}".format(cumloss / display_iters),
- "{0:.4f}".format(partloss / display_iters),
- "{0:.4f}".format(locrefloss / display_iters),
- "{0:.4f}".format(pwloss / display_iters),
- current_lr,
+ if lrf:
+ lrf.write(
+ "iteration: {}, loss: {}, scmap loss: {}, locref loss: {}, limb loss: {}, lr: {}\n".format(
+ it,
+ "{0:.4f}".format(cumloss / display_iters),
+ "{0:.4f}".format(partloss / display_iters),
+ "{0:.4f}".format(locrefloss / display_iters),
+ "{0:.4f}".format(pwloss / display_iters),
+ current_lr,
+ )
)
- )
cumloss, partloss, locrefloss, pwloss = 0.0, 0.0, 0.0, 0.0
- lrf.flush()
+ if lrf:
+ lrf.flush()
# Save snapshot
if (it % save_iters == 0 and it != start_iter) or it == max_iter:
model_name = cfg["snapshot_prefix"]
saver.save(sess, model_name, global_step=it)
-
- lrf.close()
+ if lrf:
+ lrf.close()
sess.close()
coord.request_stop()
diff --git a/deeplabcut/pose_estimation_tensorflow/lib/__init__.py b/deeplabcut/pose_estimation_tensorflow/lib/__init__.py
index 6b45344c4b..52c30a86f5 100644
--- a/deeplabcut/pose_estimation_tensorflow/lib/__init__.py
+++ b/deeplabcut/pose_estimation_tensorflow/lib/__init__.py
@@ -8,15 +8,8 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-"""
-DeepLabCut2.0 Toolbox (deeplabcut.org)
-© A. & M. Mathis Labs
-https://github.com/DeepLabCut/DeepLabCut
-Please see AUTHORS for contributors.
-https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
-Licensed under GNU Lesser General Public License v3.0
-
-"""
-
-from deeplabcut.pose_estimation_tensorflow.lib import *
+# imports for backwards compatibility
+import deeplabcut.core.crossvalutils
+import deeplabcut.core.inferenceutils
+import deeplabcut.core.trackingutils
diff --git a/deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py b/deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py
index 706562bfbf..50f31dddf6 100644
--- a/deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py
+++ b/deeplabcut/pose_estimation_tensorflow/lib/crossvalutils.py
@@ -8,460 +8,5 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-
-
-import os
-import pickle
-import shutil
-from collections import defaultdict
-from copy import deepcopy
-from tqdm import tqdm
-
-import networkx as nx
-import numpy as np
-import pandas as pd
-from scipy.spatial import cKDTree
-from sklearn.metrics.cluster import contingency_matrix
-
-from deeplabcut.pose_estimation_tensorflow.lib.inferenceutils import (
- Assembler,
- evaluate_assembly,
- _parse_ground_truth_data,
-)
-from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions
-
-
-def _set_up_evaluation(data):
- params = dict()
- params["joint_names"] = data["metadata"]["all_joints_names"]
- params["num_joints"] = len(params["joint_names"])
- partaffinityfield_graph = data["metadata"]["PAFgraph"]
- params["paf"] = np.arange(len(partaffinityfield_graph))
- params["paf_graph"] = params["paf_links"] = [
- partaffinityfield_graph[l] for l in params["paf"]
- ]
- params["bpts"] = params["ibpts"] = range(params["num_joints"])
- params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
- return params
-
-
-def _form_original_path(path):
- root, filename = os.path.split(path)
- base, ext = os.path.splitext(filename)
- return os.path.join(root, filename.split("c")[0] + ext)
-
-
-def _unsorted_unique(array):
- _, inds = np.unique(array, return_index=True)
- return np.asarray(array)[np.sort(inds)]
-
-
-def _find_closest_neighbors(query, ref, k=3):
- n_preds = ref.shape[0]
- tree = cKDTree(ref)
- dist, inds = tree.query(query, k=k)
- idx = np.argsort(dist[:, 0])
- neighbors = np.full(len(query), -1, dtype=int)
- picked = set()
- for i, ind in enumerate(inds[idx]):
- for j in ind:
- if j not in picked:
- picked.add(j)
- neighbors[idx[i]] = j
- break
- if len(picked) == n_preds:
- break
- return neighbors
-
-
-def _calc_separability(
- vals_left, vals_right, n_bins=101, metric="jeffries", max_sensitivity=False
-):
- if metric not in ("jeffries", "auc"):
- raise ValueError("`metric` should be either 'jeffries' or 'auc'.")
-
- bins = np.linspace(0, 1, n_bins)
- hist_left = np.histogram(vals_left, bins=bins)[0]
- hist_left = hist_left / hist_left.sum()
- hist_right = np.histogram(vals_right, bins=bins)[0]
- hist_right = hist_right / hist_right.sum()
- tpr = np.cumsum(hist_right)
- if metric == "jeffries":
- sep = np.sqrt(
- 2 * (1 - np.sum(np.sqrt(hist_left * hist_right)))
- ) # Jeffries-Matusita distance
- else:
- sep = np.trapz(np.cumsum(hist_left), tpr)
- if max_sensitivity:
- threshold = bins[max(1, np.argmax(tpr > 0))]
- else:
- threshold = bins[np.argmin(1 - np.cumsum(hist_left) + tpr)]
- return sep, threshold
-
-
-def _calc_within_between_pafs(
- data,
- metadata,
- per_edge=True,
- train_set_only=True,
-):
- data = deepcopy(data)
- train_inds = set(metadata["data"]["trainIndices"])
- graph = data["metadata"]["PAFgraph"]
- within_train = defaultdict(list)
- within_test = defaultdict(list)
- between_train = defaultdict(list)
- between_test = defaultdict(list)
- for i, (key, dict_) in enumerate(data.items()):
- if key == "metadata":
- continue
-
- is_train = i in train_inds
- if train_set_only and not is_train:
- continue
-
- df = dict_["groundtruth"][2]
- try:
- df.drop("single", level="individuals", inplace=True)
- except KeyError:
- pass
- bpts = df.index.get_level_values("bodyparts").unique().to_list()
- coords_gt = (
- df.unstack(["individuals", "coords"])
- .reindex(bpts, level="bodyparts")
- .to_numpy()
- .reshape((len(bpts), -1, 2))
- )
- if np.isnan(coords_gt).all():
- continue
-
- coords = dict_["prediction"]["coordinates"][0]
- # Get animal IDs and corresponding indices in the arrays of detections
- lookup = dict()
- for i, (coord, coord_gt) in enumerate(zip(coords, coords_gt)):
- inds = np.flatnonzero(np.all(~np.isnan(coord), axis=1))
- inds_gt = np.flatnonzero(np.all(~np.isnan(coord_gt), axis=1))
- if inds.size and inds_gt.size:
- neighbors = _find_closest_neighbors(coord_gt[inds_gt], coord[inds], k=3)
- found = neighbors != -1
- lookup[i] = dict(zip(inds_gt[found], inds[neighbors[found]]))
-
- costs = dict_["prediction"]["costs"]
- for k, v in costs.items():
- paf = v["m1"]
- mask_within = np.zeros(paf.shape, dtype=bool)
- s, t = graph[k]
- if s not in lookup or t not in lookup:
- continue
- lu_s = lookup[s]
- lu_t = lookup[t]
- common_id = set(lu_s).intersection(lu_t)
- for id_ in common_id:
- mask_within[lu_s[id_], lu_t[id_]] = True
- within_vals = paf[mask_within]
- between_vals = paf[~mask_within]
- if is_train:
- within_train[k].extend(within_vals)
- between_train[k].extend(between_vals)
- else:
- within_test[k].extend(within_vals)
- between_test[k].extend(between_vals)
- if not per_edge:
- within_train = np.concatenate([*within_train.values()])
- within_test = np.concatenate([*within_test.values()])
- between_train = np.concatenate([*between_train.values()])
- between_test = np.concatenate([*between_test.values()])
- return (within_train, within_test), (between_train, between_test)
-
-
-def _benchmark_paf_graphs(
- config,
- inference_cfg,
- data,
- paf_inds,
- greedy=False,
- add_discarded=True,
- identity_only=False,
- calibration_file="",
- oks_sigma=0.1,
- margin=0,
- symmetric_kpts=None,
- split_inds=None,
-):
- metadata = data.pop("metadata")
- multi_bpts_orig = auxfun_multianimal.extractindividualsandbodyparts(config)[2]
- multi_bpts = [j for j in metadata["all_joints_names"] if j in multi_bpts_orig]
- n_multi = len(multi_bpts)
- data_ = {"metadata": metadata}
- for k, v in data.items():
- data_[k] = v["prediction"]
- ass = Assembler(
- data_,
- max_n_individuals=inference_cfg["topktoretain"],
- n_multibodyparts=n_multi,
- greedy=greedy,
- pcutoff=inference_cfg.get("pcutoff", 0.1),
- min_affinity=inference_cfg.get("pafthreshold", 0.1),
- add_discarded=add_discarded,
- identity_only=identity_only,
- )
- if calibration_file:
- ass.calibrate(calibration_file)
-
- params = ass.metadata
- image_paths = params["imnames"]
- bodyparts = params["joint_names"]
- idx = (
- data[image_paths[0]]["groundtruth"][2]
- .unstack("coords")
- .reindex(bodyparts, level="bodyparts")
- .index
- )
- mask_multi = idx.get_level_values("individuals") != "single"
- if not mask_multi.all():
- idx = idx.drop("single", level="individuals")
- individuals = idx.get_level_values("individuals").unique()
- n_individuals = len(individuals)
- map_ = dict(zip(individuals, range(n_individuals)))
-
- # Form ground truth beforehand
- ground_truth = []
- for i, imname in enumerate(image_paths):
- temp = data[imname]["groundtruth"][2].reindex(multi_bpts, level="bodyparts")
- ground_truth.append(temp.to_numpy().reshape((-1, 2)))
- ground_truth = np.stack(ground_truth)
- temp = np.ones((*ground_truth.shape[:2], 3))
- temp[..., :2] = ground_truth
- temp = temp.reshape((temp.shape[0], n_individuals, -1, 3))
- ass_true_dict = _parse_ground_truth_data(temp)
- ids = np.vectorize(map_.get)(idx.get_level_values("individuals").to_numpy())
- ground_truth = np.insert(ground_truth, 2, ids, axis=2)
-
- # Assemble animals on the full set of detections
- paf_inds = sorted(paf_inds, key=len)
- n_graphs = len(paf_inds)
- all_scores = []
- all_metrics = []
- all_assemblies = []
- for j, paf in enumerate(paf_inds, start=1):
- print(f"Graph {j}|{n_graphs}")
- ass.paf_inds = paf
- ass.assemble()
- all_assemblies.append((ass.assemblies, ass.unique, ass.metadata["imnames"]))
- if split_inds is not None:
- oks = []
-
- # get the indices of the images in the training set
- dataset_idx = [data[image_name]["index"] for image_name in image_paths]
- for inds in split_inds:
- ass_gt = {
- k: v for k, v in ass_true_dict.items() if dataset_idx[k] in inds
- }
- ass_pred = {
- k: v for k, v in ass.assemblies.items() if dataset_idx[k] in inds
- }
-
- oks.append(
- evaluate_assembly(
- ass_pred,
- ass_gt,
- oks_sigma,
- margin=margin,
- symmetric_kpts=symmetric_kpts,
- greedy_matching=inference_cfg.get("greedy_oks", False),
- )
- )
- else:
- oks = evaluate_assembly(
- ass.assemblies,
- ass_true_dict,
- oks_sigma,
- margin=margin,
- symmetric_kpts=symmetric_kpts,
- greedy_matching=inference_cfg.get("greedy_oks", False),
- )
- all_metrics.append(oks)
- scores = np.full((len(image_paths), 2), np.nan)
- for i, imname in enumerate(tqdm(image_paths)):
- gt = ground_truth[i]
- gt = gt[~np.isnan(gt).any(axis=1)]
- if len(np.unique(gt[:, 2])) < 2: # Only consider frames with 2+ animals
- continue
-
- # Count the number of unassembled bodyparts
- n_dets = len(gt)
- animals = ass.assemblies.get(i)
- if animals is None:
- if n_dets:
- scores[i, 0] = 1
- else:
- animals = [
- np.c_[animal.data, np.ones(animal.data.shape[0]) * n]
- for n, animal in enumerate(animals)
- ]
- hyp = np.concatenate(animals)
- hyp = hyp[~np.isnan(hyp).any(axis=1)]
- scores[i, 0] = max(0, (n_dets - hyp.shape[0]) / n_dets)
- neighbors = _find_closest_neighbors(gt[:, :2], hyp[:, :2])
- valid = neighbors != -1
- id_gt = gt[valid, 2]
- id_hyp = hyp[neighbors[valid], -1]
- mat = contingency_matrix(id_gt, id_hyp)
- purity = mat.max(axis=0).sum() / mat.sum()
- scores[i, 1] = purity
- all_scores.append((scores, paf))
-
- dfs = []
- for score, inds in all_scores:
- df = pd.DataFrame(score, columns=["miss", "purity"])
- df["ngraph"] = len(inds)
- dfs.append(df)
- big_df = pd.concat(dfs)
- group = big_df.groupby("ngraph")
- return (all_scores, group.agg(["mean", "std"]).T, all_metrics, all_assemblies)
-
-
-def _get_n_best_paf_graphs(
- data,
- metadata,
- full_graph,
- n_graphs=10,
- root=None,
- which="best",
- ignore_inds=None,
- metric="auc",
-):
- if which not in ("best", "worst"):
- raise ValueError('`which` must be either "best" or "worst"')
-
- (within_train, _), (between_train, _) = _calc_within_between_pafs(
- data,
- metadata,
- train_set_only=True,
- )
- # Handle unlabeled bodyparts...
- existing_edges = set(k for k, v in within_train.items() if v)
- if ignore_inds is not None:
- existing_edges = existing_edges.difference(ignore_inds)
- existing_edges = list(existing_edges)
-
- if not any(between_train.values()):
- # Only 1 animal, let us return the full graph indices only
- return ([existing_edges], dict(zip(existing_edges, [0] * len(existing_edges))))
-
- scores, _ = zip(
- *[
- _calc_separability(between_train[n], within_train[n], metric=metric)
- for n in existing_edges
- ]
- )
-
- # Find minimal skeleton
- G = nx.Graph()
- for edge, score in zip(existing_edges, scores):
- if np.isfinite(score):
- G.add_edge(*full_graph[edge], weight=score)
- if which == "best":
- order = np.asarray(existing_edges)[np.argsort(scores)[::-1]]
- if root is None:
- root = []
- for edge in nx.maximum_spanning_edges(G, data=False):
- root.append(full_graph.index(sorted(edge)))
- else:
- order = np.asarray(existing_edges)[np.argsort(scores)]
- if root is None:
- root = []
- for edge in nx.minimum_spanning_edges(G, data=False):
- root.append(full_graph.index(sorted(edge)))
-
- n_edges = len(existing_edges) - len(root)
- lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:]
- order = order[np.isin(order, root, invert=True)]
- paf_inds = [root]
- for length in lengths:
- paf_inds.append(root + list(order[:length]))
- return paf_inds, dict(zip(existing_edges, scores))
-
-
-def cross_validate_paf_graphs(
- config,
- inference_config,
- full_data_file,
- metadata_file,
- output_name="",
- pcutoff=0.1,
- oks_sigma=0.1,
- margin=0,
- greedy=False,
- add_discarded=True,
- calibrate=False,
- overwrite_config=True,
- n_graphs=10,
- paf_inds=None,
- symmetric_kpts=None,
-):
- cfg = auxiliaryfunctions.read_config(config)
- inf_cfg = auxiliaryfunctions.read_plainconfig(inference_config)
- inf_cfg_temp = inf_cfg.copy()
- inf_cfg_temp["pcutoff"] = pcutoff
-
- with open(full_data_file, "rb") as file:
- data = pickle.load(file)
- with open(metadata_file, "rb") as file:
- metadata = pickle.load(file)
-
- params = _set_up_evaluation(data)
- to_ignore = auxfun_multianimal.filter_unwanted_paf_connections(
- cfg, params["paf_graph"]
- )
- best_graphs = _get_n_best_paf_graphs(
- data,
- metadata,
- params["paf_graph"],
- ignore_inds=to_ignore,
- n_graphs=n_graphs,
- )
- paf_scores = best_graphs[1]
- if paf_inds is None:
- paf_inds = best_graphs[0]
-
- if calibrate:
- trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg)
- calibration_file = os.path.join(
- cfg["project_path"],
- str(trainingsetfolder),
- "CollectedData_" + cfg["scorer"] + ".h5",
- )
- else:
- calibration_file = ""
-
- results = _benchmark_paf_graphs(
- cfg,
- inf_cfg_temp,
- data,
- paf_inds,
- greedy,
- add_discarded,
- oks_sigma=oks_sigma,
- margin=margin,
- symmetric_kpts=symmetric_kpts,
- calibration_file=calibration_file,
- split_inds=[
- metadata["data"]["trainIndices"],
- metadata["data"]["testIndices"],
- ],
- )
- # Select optimal PAF graph
- df = results[1]
- size_opt = np.argmax((1 - df.loc["miss", "mean"]) * df.loc["purity", "mean"])
- pose_config = inference_config.replace("inference_cfg", "pose_cfg")
- if not overwrite_config:
- shutil.copy(pose_config, pose_config.replace(".yaml", "_old.yaml"))
- inds = list(paf_inds[size_opt])
- auxiliaryfunctions.edit_config(
- pose_config, {"paf_best": [int(ind) for ind in inds]}
- )
- if output_name:
- with open(output_name, "wb") as file:
- pickle.dump([results], file)
- return results[:3], paf_scores, results[3][size_opt]
+"""Backwards compatibility"""
+from deeplabcut.core.crossvalutils import *
diff --git a/deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py b/deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
index c9e5456823..889311441c 100644
--- a/deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
+++ b/deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
@@ -8,1082 +8,5 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-
-import heapq
-import itertools
-import multiprocessing
-import networkx as nx
-import numpy as np
-import operator
-import pandas as pd
-import pickle
-import warnings
-from collections import defaultdict
-from dataclasses import dataclass
-from math import sqrt, erf
-from scipy.optimize import linear_sum_assignment
-from scipy.spatial import cKDTree
-from scipy.spatial.distance import pdist, cdist
-from scipy.special import softmax
-from scipy.stats import gaussian_kde, chi2
-from tqdm import tqdm
-from typing import Tuple
-
-
-def _conv_square_to_condensed_indices(ind_row, ind_col, n):
- if ind_row == ind_col:
- raise ValueError("There are no diagonal elements in condensed matrices.")
-
- if ind_row < ind_col:
- ind_row, ind_col = ind_col, ind_row
- return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col
-
-
-Position = Tuple[float, float]
-
-
-@dataclass(frozen=True)
-class Joint:
- pos: Position
- confidence: float = 1.0
- label: int = None
- idx: int = None
- group: int = -1
-
-
-class Link:
- def __init__(self, j1, j2, affinity=1):
- self.j1 = j1
- self.j2 = j2
- self.affinity = affinity
- self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2)
-
- def __repr__(self):
- return (
- f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}"
- )
-
- @property
- def confidence(self):
- return self.j1.confidence * self.j2.confidence
-
- @property
- def idx(self):
- return self.j1.idx, self.j2.idx
-
- @property
- def length(self):
- return self._length
-
- @length.setter
- def length(self, length):
- self._length = length
-
- def to_vector(self):
- return [*self.j1.pos, *self.j2.pos]
-
-
-class Assembly:
- def __init__(self, size):
- self.data = np.full((size, 4), np.nan)
- self.confidence = 0 # 0 by default, overwritten otherwise with `add_joint`
- self._affinity = 0
- self._links = []
- self._visible = set()
- self._idx = set()
- self._dict = dict()
-
- def __len__(self):
- return len(self._visible)
-
- def __contains__(self, assembly):
- return bool(self._visible.intersection(assembly._visible))
-
- def __add__(self, other):
- if other in self:
- raise ValueError("Assemblies contain shared joints.")
-
- assembly = Assembly(self.data.shape[0])
- for link in self._links + other._links:
- assembly.add_link(link)
- return assembly
-
- @classmethod
- def from_array(cls, array):
- n_bpts, n_cols = array.shape
- ass = cls(size=n_bpts)
- ass.data[:, :n_cols] = array
- visible = np.flatnonzero(~np.isnan(array).any(axis=1))
- if n_cols < 3: # Only xy coordinates are being set
- ass.data[visible, 2] = 1 # Set detection confidence to 1
- ass._visible.update(visible)
- return ass
-
- @property
- def xy(self):
- return self.data[:, :2]
-
- @property
- def extent(self):
- bbox = np.empty(4)
- bbox[:2] = np.nanmin(self.xy, axis=0)
- bbox[2:] = np.nanmax(self.xy, axis=0)
- return bbox
-
- @property
- def area(self):
- x1, y1, x2, y2 = self.extent
- return (x2 - x1) * (y2 - y1)
-
- @property
- def confidence(self):
- return np.nanmean(self.data[:, 2])
-
- @confidence.setter
- def confidence(self, confidence):
- self.data[:, 2] = confidence
-
- @property
- def soft_identity(self):
- data = self.data[~np.isnan(self.data).any(axis=1)]
- unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True)
- avg = np.bincount(idx, weights=data[:, 2]) / cnt
- soft = softmax(avg)
- return dict(zip(unq.astype(int), soft))
-
- @property
- def affinity(self):
- n_links = self.n_links
- if not n_links:
- return 0
- return self._affinity / n_links
-
- @property
- def n_links(self):
- return len(self._links)
-
- def intersection_with(self, other):
- x11, y11, x21, y21 = self.extent
- x12, y12, x22, y22 = other.extent
- x1 = max(x11, x12)
- y1 = max(y11, y12)
- x2 = min(x21, x22)
- y2 = min(y21, y22)
- if x2 < x1 or y2 < y1:
- return 0
- ll = np.array([x1, y1])
- ur = np.array([x2, y2])
- xy1 = self.xy[~np.isnan(self.xy).any(axis=1)]
- xy2 = other.xy[~np.isnan(other.xy).any(axis=1)]
- in1 = np.all((xy1 >= ll) & (xy1 <= ur), axis=1).sum()
- in2 = np.all((xy2 >= ll) & (xy2 <= ur), axis=1).sum()
- return min(in1 / len(self), in2 / len(other))
-
- def add_joint(self, joint):
- if joint.label in self._visible or joint.label is None:
- return False
- self.data[joint.label] = *joint.pos, joint.confidence, joint.group
- self._visible.add(joint.label)
- self._idx.add(joint.idx)
- return True
-
- def remove_joint(self, joint):
- if joint.label not in self._visible:
- return False
- self.data[joint.label] = np.nan
- self._visible.remove(joint.label)
- self._idx.remove(joint.idx)
- return True
-
- def add_link(self, link, store_dict=False):
- if store_dict:
- # Selective copy; deepcopy is >5x slower
- self._dict = {
- "data": self.data.copy(),
- "_affinity": self._affinity,
- "_links": self._links.copy(),
- "_visible": self._visible.copy(),
- "_idx": self._idx.copy(),
- }
- i1, i2 = link.idx
- if i1 in self._idx and i2 in self._idx:
- self._affinity += link.affinity
- self._links.append(link)
- return False
- if link.j1.label in self._visible and link.j2.label in self._visible:
- return False
- self.add_joint(link.j1)
- self.add_joint(link.j2)
- self._affinity += link.affinity
- self._links.append(link)
- return True
-
- def calc_pairwise_distances(self):
- return pdist(self.xy, metric="sqeuclidean")
-
-
-class Assembler:
- def __init__(
- self,
- data,
- *,
- max_n_individuals,
- n_multibodyparts,
- graph=None,
- paf_inds=None,
- greedy=False,
- pcutoff=0.1,
- min_affinity=0.05,
- min_n_links=2,
- max_overlap=0.8,
- identity_only=False,
- nan_policy="little",
- force_fusion=False,
- add_discarded=False,
- window_size=0,
- method="m1",
- ):
- self.data = data
- self.metadata = self.parse_metadata(self.data)
- self.max_n_individuals = max_n_individuals
- self.n_multibodyparts = n_multibodyparts
- self.n_uniquebodyparts = self.n_keypoints - n_multibodyparts
- self.greedy = greedy
- self.pcutoff = pcutoff
- self.min_affinity = min_affinity
- self.min_n_links = min_n_links
- self.max_overlap = max_overlap
- self._has_identity = "identity" in self[0]
- if identity_only and not self._has_identity:
- warnings.warn(
- "The network was not trained with identity; setting `identity_only` to False."
- )
- self.identity_only = identity_only & self._has_identity
- self.nan_policy = nan_policy
- self.force_fusion = force_fusion
- self.add_discarded = add_discarded
- self.window_size = window_size
- self.method = method
- self.graph = graph or self.metadata["paf_graph"]
- self.paf_inds = paf_inds or self.metadata["paf"]
- self._gamma = 0.01
- self._trees = dict()
- self.safe_edge = False
- self._kde = None
- self.assemblies = dict()
- self.unique = dict()
-
- def __getitem__(self, item):
- return self.data[self.metadata["imnames"][item]]
-
- @property
- def n_keypoints(self):
- return self.metadata["num_joints"]
-
- def calibrate(self, train_data_file):
- df = pd.read_hdf(train_data_file)
- try:
- df.drop("single", level="individuals", axis=1, inplace=True)
- except KeyError:
- pass
- n_bpts = len(df.columns.get_level_values("bodyparts").unique())
- if n_bpts == 1:
- warnings.warn("There is only one keypoint; skipping calibration...")
- return
-
- xy = df.to_numpy().reshape((-1, n_bpts, 2))
- frac_valid = np.mean(~np.isnan(xy), axis=(1, 2))
- # Only keeps skeletons that are more than 90% complete
- xy = xy[frac_valid >= 0.9]
- if not xy.size:
- warnings.warn("No complete poses were found. Skipping calibration...")
- return
-
- # TODO Normalize dists by longest length?
- # TODO Smarter imputation technique (Bayesian? Grassmann averages?)
- dists = np.vstack([pdist(data, "sqeuclidean") for data in xy])
- mu = np.nanmean(dists, axis=0)
- missing = np.isnan(dists)
- dists = np.where(missing, mu, dists)
- try:
- kde = gaussian_kde(dists.T)
- kde.mean = mu
- self._kde = kde
- self.safe_edge = True
- except np.linalg.LinAlgError:
- # Covariance matrix estimation fails due to numerical singularities
- warnings.warn(
- "The assembler could not be robustly calibrated. Continuing without it..."
- )
-
- def calc_assembly_mahalanobis_dist(
- self, assembly, return_proba=False, nan_policy="little"
- ):
- if self._kde is None:
- raise ValueError("Assembler should be calibrated first with training data.")
-
- dists = assembly.calc_pairwise_distances() - self._kde.mean
- mask = np.isnan(dists)
- # Distance is undefined if the assembly is empty
- if not len(assembly) or mask.all():
- if return_proba:
- return np.inf, 0
- return np.inf
-
- if nan_policy == "little":
- inds = np.flatnonzero(~mask)
- dists = dists[inds]
- inv_cov = self._kde.inv_cov[np.ix_(inds, inds)]
- # Correct distance to account for missing observations
- factor = self._kde.d / len(inds)
- else:
- # Alternatively, reduce contribution of missing values to the Mahalanobis
- # distance to zero by substituting the corresponding means.
- dists[mask] = 0
- mask.fill(False)
- inv_cov = self._kde.inv_cov
- factor = 1
- dot = dists @ inv_cov
- mahal = factor * sqrt(np.sum((dot * dists), axis=-1))
- if return_proba:
- proba = 1 - chi2.cdf(mahal, np.sum(~mask))
- return mahal, proba
- return mahal
-
- def calc_link_probability(self, link):
- if self._kde is None:
- raise ValueError("Assembler should be calibrated first with training data.")
-
- i = link.j1.label
- j = link.j2.label
- ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts)
- mu = self._kde.mean[ind]
- sigma = self._kde.covariance[ind, ind]
- z = (link.length**2 - mu) / sigma
- return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2))))
-
- @staticmethod
- def _flatten_detections(data_dict):
- ind = 0
- coordinates = data_dict["coordinates"][0]
- confidence = data_dict["confidence"]
- ids = data_dict.get("identity", None)
- if ids is None:
- ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence]
- else:
- ids = [arr.argmax(axis=1) for arr in ids]
- for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)):
- if not np.any(coords):
- continue
- for xy, p, g in zip(coords, conf, id_):
- joint = Joint(tuple(xy), p.item(), i, ind, g)
- ind += 1
- yield joint
-
- def extract_best_links(self, joints_dict, costs, trees=None):
- links = []
- for ind in self.paf_inds:
- s, t = self.graph[ind]
- dets_s = joints_dict.get(s, None)
- dets_t = joints_dict.get(t, None)
- if dets_s is None or dets_t is None:
- continue
- if ind not in costs:
- continue
- lengths = costs[ind]["distance"]
- if np.isinf(lengths).all():
- continue
- aff = costs[ind][self.method].copy()
- aff[np.isnan(aff)] = 0
-
- if trees:
- vecs = np.vstack(
- [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t]
- )
- dists = []
- for n, tree in enumerate(trees, start=1):
- d, _ = tree.query(vecs)
- dists.append(np.exp(-self._gamma * n * d))
- w = np.mean(dists, axis=0)
- aff *= w.reshape(aff.shape)
-
- if self.greedy:
- conf = np.asarray(
- [
- [det_s.confidence * det_t.confidence for det_t in dets_t]
- for det_s in dets_s
- ]
- )
- rows, cols = np.where(
- (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity)
- )
- candidates = sorted(
- zip(rows, cols, aff[rows, cols], lengths[rows, cols]),
- key=lambda x: x[2],
- reverse=True,
- )
- i_seen = set()
- j_seen = set()
- for i, j, w, l in candidates:
- if i not in i_seen and j not in j_seen:
- i_seen.add(i)
- j_seen.add(j)
- links.append(Link(dets_s[i], dets_t[j], w))
- if len(i_seen) == self.max_n_individuals:
- break
- else: # Optimal keypoint pairing
- inds_s = sorted(
- range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True
- )[: self.max_n_individuals]
- inds_t = sorted(
- range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True
- )[: self.max_n_individuals]
- keep_s = [
- ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff
- ]
- keep_t = [
- ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff
- ]
- aff = aff[np.ix_(keep_s, keep_t)]
- rows, cols = linear_sum_assignment(aff, maximize=True)
- for row, col in zip(rows, cols):
- w = aff[row, col]
- if w >= self.min_affinity:
- links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w))
- return links
-
- def _fill_assembly(self, assembly, lookup, assembled, safe_edge, nan_policy):
- stack = []
- visited = set()
- tabu = []
- counter = itertools.count()
-
- def push_to_stack(i):
- for j, link in lookup[i].items():
- if j in assembly._idx:
- continue
- if link.idx in visited:
- continue
- heapq.heappush(stack, (-link.affinity, next(counter), link))
- visited.add(link.idx)
-
- for idx in assembly._idx:
- push_to_stack(idx)
-
- while stack and len(assembly) < self.n_multibodyparts:
- _, _, best = heapq.heappop(stack)
- i, j = best.idx
- if i in assembly._idx:
- new_ind = j
- elif j in assembly._idx:
- new_ind = i
- else:
- continue
- if new_ind in assembled:
- continue
- if safe_edge:
- d_old = self.calc_assembly_mahalanobis_dist(
- assembly, nan_policy=nan_policy
- )
- success = assembly.add_link(best, store_dict=True)
- if not success:
- assembly._dict = dict()
- continue
- d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy)
- if d < d_old:
- push_to_stack(new_ind)
- try:
- _, _, link = heapq.heappop(tabu)
- heapq.heappush(stack, (-link.affinity, next(counter), link))
- except IndexError:
- pass
- else:
- heapq.heappush(tabu, (d - d_old, next(counter), best))
- assembly.__dict__.update(assembly._dict)
- assembly._dict = dict()
- else:
- assembly.add_link(best)
- push_to_stack(new_ind)
-
- def build_assemblies(self, links):
- lookup = defaultdict(dict)
- for link in links:
- i, j = link.idx
- lookup[i][j] = link
- lookup[j][i] = link
-
- assemblies = []
- assembled = set()
-
- # Fill the subsets with unambiguous, complete individuals
- G = nx.Graph([link.idx for link in links])
- for chain in nx.connected_components(G):
- if len(chain) == self.n_multibodyparts:
- edges = [tuple(sorted(edge)) for edge in G.edges(chain)]
- assembly = Assembly(self.n_multibodyparts)
- for link in links:
- i, j = link.idx
- if (i, j) in edges:
- success = assembly.add_link(link)
- if success:
- lookup[i].pop(j)
- lookup[j].pop(i)
- assembled.update(assembly._idx)
- assemblies.append(assembly)
-
- if len(assemblies) == self.max_n_individuals:
- return assemblies, assembled
-
- for link in sorted(links, key=lambda x: x.affinity, reverse=True):
- if any(i in assembled for i in link.idx):
- continue
- assembly = Assembly(self.n_multibodyparts)
- assembly.add_link(link)
- self._fill_assembly(
- assembly, lookup, assembled, self.safe_edge, self.nan_policy
- )
- for link in assembly._links:
- i, j = link.idx
- lookup[i].pop(j)
- lookup[j].pop(i)
- assembled.update(assembly._idx)
- assemblies.append(assembly)
-
- # Fuse superfluous assemblies
- n_extra = len(assemblies) - self.max_n_individuals
- if n_extra > 0:
- if self.safe_edge:
- ds_old = [
- self.calc_assembly_mahalanobis_dist(assembly)
- for assembly in assemblies
- ]
- while len(assemblies) > self.max_n_individuals:
- ds = []
- for i, j in itertools.combinations(range(len(assemblies)), 2):
- if assemblies[j] not in assemblies[i]:
- temp = assemblies[i] + assemblies[j]
- d = self.calc_assembly_mahalanobis_dist(temp)
- delta = d - max(ds_old[i], ds_old[j])
- ds.append((i, j, delta, d, temp))
- if not ds:
- break
- min_ = sorted(ds, key=lambda x: x[2])
- i, j, delta, d, new = min_[0]
- if delta < 0 or len(min_) == 1:
- assemblies[i] = new
- assemblies.pop(j)
- ds_old[i] = d
- ds_old.pop(j)
- else:
- break
- elif self.force_fusion:
- assemblies = sorted(assemblies, key=len)
- for nrow in range(n_extra):
- assembly = assemblies[nrow]
- candidates = [a for a in assemblies[nrow:] if assembly not in a]
- if not candidates:
- continue
- if len(candidates) == 1:
- candidate = candidates[0]
- else:
- dists = []
- for cand in candidates:
- d = cdist(assembly.xy, cand.xy)
- dists.append(np.nanmin(d))
- candidate = candidates[np.argmin(dists)]
- ind = assemblies.index(candidate)
- assemblies[ind] += assembly
- else:
- store = dict()
- for assembly in assemblies:
- if len(assembly) != self.n_multibodyparts:
- for i in assembly._idx:
- store[i] = assembly
- used = [link for assembly in assemblies for link in assembly._links]
- unconnected = [link for link in links if link not in used]
- for link in unconnected:
- i, j = link.idx
- try:
- if store[j] not in store[i]:
- temp = store[i] + store[j]
- store[i].__dict__.update(temp.__dict__)
- assemblies.remove(store[j])
- for idx in store[j]._idx:
- store[idx] = store[i]
- except KeyError:
- pass
-
- # Second pass without edge safety
- for assembly in assemblies:
- if len(assembly) != self.n_multibodyparts:
- self._fill_assembly(assembly, lookup, assembled, False, "")
- assembled.update(assembly._idx)
-
- return assemblies, assembled
-
- def _assemble(self, data_dict, ind_frame):
- joints = list(self._flatten_detections(data_dict))
- if not joints:
- return None, None
-
- bag = defaultdict(list)
- for joint in joints:
- bag[joint.label].append(joint)
-
- assembled = set()
-
- if self.n_uniquebodyparts:
- unique = np.full((self.n_uniquebodyparts, 3), np.nan)
- for n, ind in enumerate(range(self.n_multibodyparts, self.n_keypoints)):
- dets = bag[ind]
- if not dets:
- continue
- if len(dets) > 1:
- det = max(dets, key=lambda x: x.confidence)
- else:
- det = dets[0]
- # Mark the unique body parts as assembled anyway so
- # they are not used later on to fill assemblies.
- assembled.update(d.idx for d in dets)
- if det.confidence <= self.pcutoff and not self.add_discarded:
- continue
- unique[n] = *det.pos, det.confidence
- if np.isnan(unique).all():
- unique = None
- else:
- unique = None
-
- if not any(i in bag for i in range(self.n_multibodyparts)):
- return None, unique
-
- if self.n_multibodyparts == 1:
- assemblies = []
- for joint in bag[0]:
- if joint.confidence >= self.pcutoff:
- ass = Assembly(self.n_multibodyparts)
- ass.add_joint(joint)
- assemblies.append(ass)
- return assemblies, unique
-
- if self.max_n_individuals == 1:
- get_attr = operator.attrgetter("confidence")
- ass = Assembly(self.n_multibodyparts)
- for ind in range(self.n_multibodyparts):
- joints = bag[ind]
- if not joints:
- continue
- ass.add_joint(max(joints, key=get_attr))
- return [ass], unique
-
- if self.identity_only:
- assemblies = []
- get_attr = operator.attrgetter("group")
- temp = sorted(
- (joint for joint in joints if np.isfinite(joint.confidence)),
- key=get_attr,
- )
- groups = itertools.groupby(temp, get_attr)
- for _, group in groups:
- ass = Assembly(self.n_multibodyparts)
- for joint in sorted(group, key=lambda x: x.confidence, reverse=True):
- if (
- joint.confidence >= self.pcutoff
- and joint.label < self.n_multibodyparts
- ):
- ass.add_joint(joint)
- if len(ass):
- assemblies.append(ass)
- assembled.update(ass._idx)
- else:
- trees = []
- for j in range(1, self.window_size + 1):
- tree = self._trees.get(ind_frame - j, None)
- if tree is not None:
- trees.append(tree)
-
- links = self.extract_best_links(bag, data_dict["costs"], trees)
- if self._kde:
- for link in links[::-1]:
- p = max(self.calc_link_probability(link), 0.001)
- link.affinity *= p
- if link.affinity < self.min_affinity:
- links.remove(link)
-
- if self.window_size >= 1 and links:
- # Store selected edges for subsequent frames
- vecs = np.vstack([link.to_vector() for link in links])
- self._trees[ind_frame] = cKDTree(vecs)
-
- assemblies, assembled_ = self.build_assemblies(links)
- assembled.update(assembled_)
-
- # Remove invalid assemblies
- discarded = set(
- joint
- for joint in joints
- if joint.idx not in assembled and np.isfinite(joint.confidence)
- )
- for assembly in assemblies[::-1]:
- if 0 < assembly.n_links < self.min_n_links or not len(assembly):
- for link in assembly._links:
- discarded.update((link.j1, link.j2))
- assemblies.remove(assembly)
- if 0 < self.max_overlap < 1: # Non-maximum pose suppression
- if self._kde is not None:
- scores = [
- -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies
- ]
- else:
- scores = [ass._affinity for ass in assemblies]
- lst = list(zip(scores, assemblies))
- assemblies = []
- while lst:
- temp = max(lst, key=lambda x: x[0])
- lst.remove(temp)
- assemblies.append(temp[1])
- for pair in lst[::-1]:
- if temp[1].intersection_with(pair[1]) >= self.max_overlap:
- lst.remove(pair)
- if len(assemblies) > self.max_n_individuals:
- assemblies = sorted(assemblies, key=len, reverse=True)
- for assembly in assemblies[self.max_n_individuals :]:
- for link in assembly._links:
- discarded.update((link.j1, link.j2))
- assemblies = assemblies[: self.max_n_individuals]
-
- if self.add_discarded and discarded:
- # Fill assemblies with unconnected body parts
- for joint in sorted(discarded, key=lambda x: x.confidence, reverse=True):
- if self.safe_edge:
- for assembly in assemblies:
- if joint.label in assembly._visible:
- continue
- d_old = self.calc_assembly_mahalanobis_dist(assembly)
- assembly.add_joint(joint)
- d = self.calc_assembly_mahalanobis_dist(assembly)
- if d < d_old:
- break
- assembly.remove_joint(joint)
- else:
- dists = []
- for i, assembly in enumerate(assemblies):
- if joint.label in assembly._visible:
- continue
- d = cdist(assembly.xy, np.atleast_2d(joint.pos))
- dists.append((i, np.nanmin(d)))
- if not dists:
- continue
- min_ = sorted(dists, key=lambda x: x[1])
- ind, _ = min_[0]
- assemblies[ind].add_joint(joint)
-
- return assemblies, unique
-
- def assemble(self, chunk_size=1, n_processes=None):
- self.assemblies = dict()
- self.unique = dict()
- # Spawning (rather than forking) multiple processes does not
- # work nicely with the GUI or interactive sessions.
- # In that case, we fall back to the serial assembly.
- if chunk_size == 0 or multiprocessing.get_start_method() == "spawn":
- for i, data_dict in enumerate(tqdm(self)):
- assemblies, unique = self._assemble(data_dict, i)
- if assemblies:
- self.assemblies[i] = assemblies
- if unique is not None:
- self.unique[i] = unique
- else:
- global wrapped # Hack to make the function pickable
-
- def wrapped(i):
- return i, self._assemble(self[i], i)
-
- n_frames = len(self.metadata["imnames"])
- with multiprocessing.Pool(n_processes) as p:
- with tqdm(total=n_frames) as pbar:
- for i, (assemblies, unique) in p.imap_unordered(
- wrapped, range(n_frames), chunksize=chunk_size
- ):
- if assemblies:
- self.assemblies[i] = assemblies
- if unique is not None:
- self.unique[i] = unique
- pbar.update()
-
- def from_pickle(self, pickle_path):
- with open(pickle_path, "rb") as file:
- data = pickle.load(file)
- self.unique = data.pop("single", {})
- self.assemblies = data
-
- @staticmethod
- def parse_metadata(data):
- params = dict()
- params["joint_names"] = data["metadata"]["all_joints_names"]
- params["num_joints"] = len(params["joint_names"])
- params["paf_graph"] = data["metadata"]["PAFgraph"]
- params["paf"] = data["metadata"].get(
- "PAFinds", np.arange(len(params["joint_names"]))
- )
- params["bpts"] = params["ibpts"] = range(params["num_joints"])
- params["imnames"] = [fn for fn in list(data) if fn != "metadata"]
- return params
-
- def to_h5(self, output_name):
- data = np.full(
- (
- len(self.metadata["imnames"]),
- self.max_n_individuals,
- self.n_multibodyparts,
- 4,
- ),
- fill_value=np.nan,
- )
- for ind, assemblies in self.assemblies.items():
- for n, assembly in enumerate(assemblies):
- data[ind, n] = assembly.data
- index = pd.MultiIndex.from_product(
- [
- ["scorer"],
- map(str, range(self.max_n_individuals)),
- map(str, range(self.n_multibodyparts)),
- ["x", "y", "likelihood"],
- ],
- names=["scorer", "individuals", "bodyparts", "coords"],
- )
- temp = data[..., :3].reshape((data.shape[0], -1))
- df = pd.DataFrame(temp, columns=index)
- df.to_hdf(output_name, key="ass")
-
- def to_pickle(self, output_name):
- data = dict()
- for ind, assemblies in self.assemblies.items():
- data[ind] = [ass.data for ass in assemblies]
- if self.unique:
- data["single"] = self.unique
- with open(output_name, "wb") as file:
- pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)
-
-
-def calc_object_keypoint_similarity(
- xy_pred,
- xy_true,
- sigma,
- margin=0,
- symmetric_kpts=None,
-):
- visible_gt = ~np.isnan(xy_true).all(axis=1)
- if visible_gt.sum() < 2: # At least 2 points needed to calculate scale
- return np.nan
- true = xy_true[visible_gt]
- scale_squared = np.product(np.ptp(true, axis=0) + np.spacing(1) + margin * 2)
- if np.isclose(scale_squared, 0):
- return np.nan
- k_squared = (2 * sigma) ** 2
- denom = 2 * scale_squared * k_squared
- if symmetric_kpts is None:
- pred = xy_pred[visible_gt]
- pred[np.isnan(pred)] = np.inf
- dist_squared = np.sum((pred - true) ** 2, axis=1)
- oks = np.exp(-dist_squared / denom)
- return np.mean(oks)
- else:
- oks = []
- xy_preds = [xy_pred]
- combos = (
- pair
- for l in range(len(symmetric_kpts))
- for pair in itertools.combinations(symmetric_kpts, l + 1)
- )
- for pairs in combos:
- # Swap corresponding keypoints
- tmp = xy_pred.copy()
- for pair in pairs:
- tmp[pair, :] = tmp[pair[::-1], :]
- xy_preds.append(tmp)
- for xy_pred in xy_preds:
- pred = xy_pred[visible_gt]
- pred[np.isnan(pred)] = np.inf
- dist_squared = np.sum((pred - true) ** 2, axis=1)
- oks.append(np.mean(np.exp(-dist_squared / denom)))
- return max(oks)
-
-
-def match_assemblies(
- ass_pred, ass_true, sigma, margin=0, symmetric_kpts=None, greedy_matching=False
-):
- # Only consider assemblies of at least two keypoints
- ass_pred = [a for a in ass_pred if len(a) > 1]
- ass_true = [a for a in ass_true if len(a) > 1]
-
- matched = []
-
- # Greedy assembly matching like in pycocotools
- if greedy_matching:
- inds_true = list(range(len(ass_true)))
- inds_pred = np.argsort(
- [ins.affinity if ins.n_links else ins.confidence for ins in ass_pred]
- )[::-1]
- for ind_pred in inds_pred:
- xy_pred = ass_pred[ind_pred].xy
- oks = []
- for ind_true in inds_true:
- xy_true = ass_true[ind_true].xy
- oks.append(
- calc_object_keypoint_similarity(
- xy_pred,
- xy_true,
- sigma,
- margin,
- symmetric_kpts,
- )
- )
- if np.all(np.isnan(oks)):
- continue
- ind_best = np.nanargmax(oks)
- ind_true_best = inds_true.pop(ind_best)
- matched.append((ass_pred[ind_pred], ass_true[ind_true_best], oks[ind_best]))
- if not inds_true:
- break
-
- # Global rather than greedy assembly matching
- else:
- mat = np.zeros((len(ass_pred), len(ass_true)))
- for i, a_pred in enumerate(ass_pred):
- for j, a_true in enumerate(ass_true):
- oks = calc_object_keypoint_similarity(
- a_pred.xy,
- a_true.xy,
- sigma,
- margin,
- symmetric_kpts,
- )
- if ~np.isnan(oks):
- mat[i, j] = oks
- rows, cols = linear_sum_assignment(mat, maximize=True)
- inds_true = list(range(len(ass_true)))
- for row, col in zip(rows, cols):
- matched.append((ass_pred[row], ass_true[col], mat[row, col]))
- _ = inds_true.remove(col)
-
- unmatched = [ass_true[ind] for ind in inds_true]
- return matched, unmatched
-
-
-def parse_ground_truth_data_file(h5_file):
- df = pd.read_hdf(h5_file)
- try:
- df.drop("single", axis=1, level="individuals", inplace=True)
- except KeyError:
- pass
- # Cast columns of dtype 'object' to float to avoid TypeError
- # further down in _parse_ground_truth_data.
- cols = df.select_dtypes(include="object").columns
- if cols.to_list():
- df[cols] = df[cols].astype("float")
- n_individuals = len(df.columns.get_level_values("individuals").unique())
- n_bodyparts = len(df.columns.get_level_values("bodyparts").unique())
- data = df.to_numpy().reshape((df.shape[0], n_individuals, n_bodyparts, -1))
- return _parse_ground_truth_data(data)
-
-
-def _parse_ground_truth_data(data):
- gt = dict()
- for i, arr in enumerate(data):
- temp = []
- for row in arr:
- if np.isnan(row[:, :2]).all():
- continue
- ass = Assembly.from_array(row)
- temp.append(ass)
- if not temp:
- continue
- gt[i] = temp
- return gt
-
-
-def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)):
- if not hasattr(Assembly, criterion):
- raise ValueError(f"Invalid criterion {criterion}.")
-
- if len(qs) != 2:
- raise ValueError(
- "Two percentiles (for lower and upper bounds) should be given."
- )
-
- tuples = []
- for frame_ind, assemblies in dict_of_assemblies.items():
- for assembly in assemblies:
- tuples.append((frame_ind, getattr(assembly, criterion)))
- frame_inds, vals = zip(*tuples)
- vals = np.asarray(vals)
- lo, up = np.percentile(vals, qs, interpolation="nearest")
- inds = np.flatnonzero((vals < lo) | (vals > up)).tolist()
- return list(set(frame_inds[i] for i in inds))
-
-
-def evaluate_assembly(
- ass_pred_dict,
- ass_true_dict,
- oks_sigma=0.072,
- oks_thresholds=np.linspace(0.5, 0.95, 10),
- margin=0,
- symmetric_kpts=None,
- greedy_matching=False,
-):
- # sigma is taken as the median of all COCO keypoint standard deviations
- all_matched = []
- all_unmatched = []
- for ind, ass_true in tqdm(ass_true_dict.items()):
- ass_pred = ass_pred_dict.get(ind, [])
- matched, unmatched = match_assemblies(
- ass_pred,
- ass_true,
- oks_sigma,
- margin,
- symmetric_kpts,
- greedy_matching,
- )
- all_matched.extend(matched)
- all_unmatched.extend(unmatched)
- if not all_matched:
- return {
- "precisions": np.array([]),
- "recalls": np.array([]),
- "mAP": 0.0,
- "mAR": 0.0,
- }
-
- conf_pred = np.asarray([match[0].affinity for match in all_matched])
- idx = np.argsort(-conf_pred, kind="mergesort")
- # Sort matching score (OKS) in descending order of assembly affinity
- oks = np.asarray([match[2] for match in all_matched])[idx]
- ntot = len(all_matched) + len(all_unmatched)
- recall_thresholds = np.linspace(0, 1, 101)
- precisions = []
- recalls = []
- for th in oks_thresholds:
- tp = np.cumsum(oks >= th)
- fp = np.cumsum(oks < th)
- rc = tp / ntot
- pr = tp / (fp + tp + np.spacing(1))
- recall = rc[-1]
- # Guarantee precision decreases monotonically
- # See https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173)
- for i in range(len(pr) - 1, 0, -1):
- if pr[i] > pr[i - 1]:
- pr[i - 1] = pr[i]
- inds_rc = np.searchsorted(rc, recall_thresholds)
- precision = np.zeros(inds_rc.shape)
- valid = inds_rc < len(pr)
- precision[valid] = pr[inds_rc[valid]]
- precisions.append(precision)
- recalls.append(recall)
- precisions = np.asarray(precisions)
- recalls = np.asarray(recalls)
- return {
- "precisions": precisions,
- "recalls": recalls,
- "mAP": precisions.mean(),
- "mAR": recalls.mean(),
- }
+"""Backwards compatibility"""
+from deeplabcut.core.inferenceutils import *
diff --git a/deeplabcut/pose_estimation_tensorflow/lib/trackingutils.py b/deeplabcut/pose_estimation_tensorflow/lib/trackingutils.py
index 7cc88a92bd..f769fa7238 100644
--- a/deeplabcut/pose_estimation_tensorflow/lib/trackingutils.py
+++ b/deeplabcut/pose_estimation_tensorflow/lib/trackingutils.py
@@ -8,829 +8,5 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-
-import abc
-import math
-import numpy as np
-import warnings
-from collections import defaultdict
-from filterpy.common import kinematic_kf
-from filterpy.kalman import KalmanFilter
-from matplotlib import patches
-from numba import jit
-from numba.core.errors import NumbaPerformanceWarning
-from scipy.optimize import linear_sum_assignment
-from scipy.stats import mode
-from tqdm import tqdm
-
-
-warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
-
-TRACK_METHODS = {
- "box": "_bx",
- "skeleton": "_sk",
- "ellipse": "_el",
- "transformer": "_tr",
-}
-
-
-def calc_iou(bbox1, bbox2):
- x1 = max(bbox1[0], bbox2[0])
- y1 = max(bbox1[1], bbox2[1])
- x2 = min(bbox1[2], bbox2[2])
- y2 = min(bbox1[3], bbox2[3])
- w = max(0, x2 - x1)
- h = max(0, y2 - y1)
- wh = w * h
- return wh / (
- (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
- + (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
- - wh
- )
-
-
-class BaseTracker:
- """Base class for a constant-velocity Kalman filter-based tracker."""
-
- n_trackers = 0
-
- def __init__(self, dim, dim_z):
- self.kf = kinematic_kf(
- dim,
- 1,
- dim_z=dim_z,
- order_by_dim=False,
- )
- self.id = self.__class__.n_trackers
- self.__class__.n_trackers += 1
- self.time_since_update = 0
- self.age = 0
- self.hits = 0
- self.hit_streak = 0
-
- def update(self, z):
- self.time_since_update = 0
- self.hits += 1
- self.hit_streak += 1
- self.kf.update(z)
-
- def predict(self):
- self.kf.predict()
- self.age += 1
- if self.time_since_update > 0:
- self.hit_streak = 0
- self.time_since_update += 1
- return self.state
-
- @property
- def state(self):
- return self.kf.x.squeeze()[: self.kf.dim_z]
-
- @state.setter
- def state(self, state):
- self.kf.x[: self.kf.dim_z] = state
-
-
-class Ellipse:
- def __init__(self, x, y, width, height, theta):
- self.x = x
- self.y = y
- self.width = width
- self.height = height
- self.theta = theta # in radians
- self._geometry = None
-
- @property
- def parameters(self):
- return self.x, self.y, self.width, self.height, self.theta
-
- @property
- def aspect_ratio(self):
- return max(self.width, self.height) / min(self.width, self.height)
-
- def calc_similarity_with(self, other_ellipse):
- max_dist = max(
- self.height, self.width, other_ellipse.height, other_ellipse.width
- )
- dist = math.sqrt(
- (self.x - other_ellipse.x) ** 2 + (self.y - other_ellipse.y) ** 2
- )
- cost1 = 1 - min(dist / max_dist, 1)
- cost2 = abs(math.cos(self.theta - other_ellipse.theta))
- return 0.8 * cost1 + 0.2 * cost2 * cost1
-
- def contains_points(self, xy, tol=0.1):
- ca = math.cos(self.theta)
- sa = math.sin(self.theta)
- x_demean = xy[:, 0] - self.x
- y_demean = xy[:, 1] - self.y
- return (
- ((ca * x_demean + sa * y_demean) ** 2 / (0.5 * self.width) ** 2)
- + ((sa * x_demean - ca * y_demean) ** 2 / (0.5 * self.height) ** 2)
- ) <= 1 + tol
-
- def draw(self, show_axes=True, ax=None, **kwargs):
- import matplotlib.pyplot as plt
- from matplotlib.lines import Line2D
- from matplotlib.transforms import Affine2D
-
- if ax is None:
- ax = plt.subplot(111, aspect="equal")
- el = patches.Ellipse(
- xy=(self.x, self.y),
- width=self.width,
- height=self.height,
- angle=np.rad2deg(self.theta),
- **kwargs,
- )
- ax.add_patch(el)
- if show_axes:
- major = Line2D([-self.width / 2, self.width / 2], [0, 0], lw=3, zorder=3)
- minor = Line2D([0, 0], [-self.height / 2, self.height / 2], lw=3, zorder=3)
- trans = (
- Affine2D().rotate(self.theta).translate(self.x, self.y) + ax.transData
- )
- major.set_transform(trans)
- minor.set_transform(trans)
- ax.add_artist(major)
- ax.add_artist(minor)
-
-
-class EllipseFitter:
- def __init__(self, sd=2):
- self.sd = sd
- self.x = None
- self.y = None
- self.params = None
- self._coeffs = None
-
- def fit(self, xy):
- self.x, self.y = xy[np.isfinite(xy).all(axis=1)].T
- if len(self.x) < 3:
- return None
- if self.sd:
- self.params = self._fit_error(self.x, self.y, self.sd)
- else:
- self._coeffs = self._fit(self.x, self.y)
- self.params = self.calc_parameters(self._coeffs)
- if not np.isnan(self.params).any():
- el = Ellipse(*self.params)
- # Regularize by forcing AR <= 5
- # max_ar = 5
- # if el.aspect_ratio >= max_ar:
- # if el.height > el.width:
- # el.width = el.height / max_ar
- # else:
- # el.height = el.width / max_ar
- # Orient the ellipse such that it encompasses most points
- # n_inside = el.contains_points(np.c_[self.x, self.y]).sum()
- # el.theta += 0.5 * np.pi
- # if el.contains_points(np.c_[self.x, self.y]).sum() < n_inside:
- # el.theta -= 0.5 * np.pi
- return el
- return None
-
- @staticmethod
- @jit(nopython=True)
- def _fit(x, y):
- """
- Least Squares ellipse fitting algorithm
- Fit an ellipse to a set of X- and Y-coordinates.
- See Halir and Flusser, 1998 for implementation details
-
- :param x: ndarray, 1D trajectory
- :param y: ndarray, 1D trajectory
- :return: 1D ndarray of 6 coefficients of the general quadratic curve:
- ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0
- """
- D1 = np.vstack((x * x, x * y, y * y))
- D2 = np.vstack((x, y, np.ones_like(x)))
- S1 = D1 @ D1.T
- S2 = D1 @ D2.T
- S3 = D2 @ D2.T
- T = -np.linalg.inv(S3) @ S2.T
- temp = S1 + S2 @ T
- M = np.zeros_like(temp)
- M[0] = temp[2] * 0.5
- M[1] = -temp[1]
- M[2] = temp[0] * 0.5
- E, V = np.linalg.eig(M)
- cond = 4 * V[0] * V[2] - V[1] ** 2
- a1 = V[:, cond > 0][:, 0]
- a2 = T @ a1
- return np.hstack((a1, a2))
-
- @staticmethod
- @jit(nopython=True)
- def _fit_error(x, y, sd):
- """
- Fit a sd-sigma covariance error ellipse to the data.
-
- :param x: ndarray, 1D input of X coordinates
- :param y: ndarray, 1D input of Y coordinates
- :param sd: int, size of the error ellipse in 'standard deviation'
- :return: ellipse center, semi-axes length, angle to the X-axis
- """
- cov = np.cov(x, y)
- E, V = np.linalg.eigh(cov) # Returns the eigenvalues in ascending order
- # r2 = chi2.ppf(2 * norm.cdf(sd) - 1, 2)
- # height, width = np.sqrt(E * r2)
- height, width = 2 * sd * np.sqrt(E)
- a, b = V[:, 1]
- rotation = math.atan2(b, a) % np.pi
- return [np.mean(x), np.mean(y), width, height, rotation]
-
- @staticmethod
- @jit(nopython=True)
- def calc_parameters(coeffs):
- """
- Calculate ellipse center coordinates, semi-axes lengths, and
- the counterclockwise angle of rotation from the x-axis to the ellipse major axis.
- Visit http://mathworld.wolfram.com/Ellipse.html
- for how to estimate ellipse parameters.
-
- :param coeffs: list of fitting coefficients
- :return: center: 1D ndarray, semi-axes: 1D ndarray, angle: float
- """
- # The general quadratic curve has the form:
- # ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0
- a, b, c, d, f, g = coeffs
- b *= 0.5
- d *= 0.5
- f *= 0.5
-
- # Ellipse center coordinates
- x0 = (c * d - b * f) / (b * b - a * c)
- y0 = (a * f - b * d) / (b * b - a * c)
-
- # Semi-axes lengths
- num = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g)
- den1 = (b * b - a * c) * (np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c))
- den2 = (b * b - a * c) * (-np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c))
- major = np.sqrt(num / den1)
- minor = np.sqrt(num / den2)
-
- # Angle to the horizontal
- if b == 0:
- if a < c:
- phi = 0
- else:
- phi = np.pi / 2
- else:
- if a < c:
- phi = np.arctan(2 * b / (a - c)) / 2
- else:
- phi = np.pi / 2 + np.arctan(2 * b / (a - c)) / 2
-
- return [x0, y0, 2 * major, 2 * minor, phi]
-
-
-class EllipseTracker(BaseTracker):
- def __init__(self, params):
- super().__init__(dim=5, dim_z=5)
- self.kf.R[2:, 2:] *= 10.0
- # High uncertainty to the unobservable initial velocities
- self.kf.P[5:, 5:] *= 1000.0
- self.kf.P *= 10.0
- self.kf.Q[5:, 5:] *= 0.01
- self.state = params
-
- @BaseTracker.state.setter
- def state(self, params):
- state = np.asarray(params).reshape((-1, 1))
- super(EllipseTracker, type(self)).state.fset(self, state)
-
-
-class SkeletonTracker(BaseTracker):
- def __init__(self, n_bodyparts):
- super().__init__(dim=n_bodyparts * 2, dim_z=n_bodyparts)
- self.kf.Q[self.kf.dim_z :, self.kf.dim_z :] *= 10
- self.kf.R[self.kf.dim_z :, self.kf.dim_z :] *= 0.01
- self.kf.P[self.kf.dim_z :, self.kf.dim_z :] *= 1000
-
- def update(self, pose):
- flat = pose.reshape((-1, 1))
- empty = np.isnan(flat).squeeze()
- if empty.any():
- H = self.kf.H.copy()
- H[empty] = 0
- flat[empty] = 0
- self.kf.update(flat, H=H)
- else:
- super().update(flat)
-
- @BaseTracker.state.setter
- def state(self, pose):
- curr_pose = pose.copy()
- empty = np.isnan(curr_pose).all(axis=1)
- if empty.any():
- fill = np.nanmean(pose, axis=0)
- curr_pose[empty] = fill
- super(SkeletonTracker, type(self)).state.fset(self, curr_pose.reshape((-1, 1)))
-
-
-class BoxTracker(BaseTracker):
- def __init__(self, bbox):
- super().__init__(dim=4, dim_z=4)
- self.kf = KalmanFilter(dim_x=7, dim_z=4)
- self.kf.F = np.array(
- [
- [1, 0, 0, 0, 1, 0, 0],
- [0, 1, 0, 0, 0, 1, 0],
- [0, 0, 1, 0, 0, 0, 1],
- [0, 0, 0, 1, 0, 0, 0],
- [0, 0, 0, 0, 1, 0, 0],
- [0, 0, 0, 0, 0, 1, 0],
- [0, 0, 0, 0, 0, 0, 1],
- ]
- )
- self.kf.H = np.array(
- [
- [1, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0],
- [0, 0, 0, 1, 0, 0, 0],
- ]
- )
- self.kf.R[2:, 2:] *= 10.0
- # Give high uncertainty to the unobservable initial velocities
- self.kf.P[4:, 4:] *= 1000.0
- self.kf.P *= 10.0
- self.kf.Q[-1, -1] *= 0.01
- self.kf.Q[4:, 4:] *= 0.01
- self.state = bbox
-
- def update(self, bbox):
- super().update(self.convert_bbox_to_z(bbox))
-
- def predict(self):
- if (self.kf.x[6] + self.kf.x[2]) <= 0:
- self.kf.x[6] *= 0.0
- return super().predict()
-
- @property
- def state(self):
- return self.convert_x_to_bbox(self.kf.x)[0]
-
- @state.setter
- def state(self, bbox):
- state = self.convert_bbox_to_z(bbox)
- super(BoxTracker, type(self)).state.fset(self, state)
-
- @staticmethod
- def convert_x_to_bbox(x, score=None):
- """
- Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
- [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
- """
- w = np.sqrt(x[2] * x[3])
- h = x[2] / w
- if score is None:
- return np.array(
- [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]
- ).reshape((1, 4))
- else:
- return np.array(
- [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]
- ).reshape((1, 5))
-
- @staticmethod
- def convert_bbox_to_z(bbox):
- """
- Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
- [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
- the aspect ratio
- """
- w = bbox[2] - bbox[0]
- h = bbox[3] - bbox[1]
- x = bbox[0] + w / 2.0
- y = bbox[1] + h / 2.0
- s = w * h # scale is just area
- r = w / float(h)
- return np.array([x, y, s, r]).reshape((4, 1))
-
-
-class SORTBase(metaclass=abc.ABCMeta):
- def __init__(self):
- self.n_frames = 0
- self.trackers = []
-
- @abc.abstractmethod
- def track(self):
- pass
-
-
-class SORTEllipse(SORTBase):
- def __init__(self, max_age, min_hits, iou_threshold, sd=2):
- self.max_age = max_age
- self.min_hits = min_hits
- self.iou_threshold = iou_threshold
- self.fitter = EllipseFitter(sd)
- EllipseTracker.n_trackers = 0
- super().__init__()
-
- def track(self, poses, identities=None):
- self.n_frames += 1
-
- trackers = np.zeros((len(self.trackers), 6))
- for i in range(len(trackers)):
- trackers[i, :5] = self.trackers[i].predict()
- empty = np.isnan(trackers).any(axis=1)
- trackers = trackers[~empty]
- for ind in np.flatnonzero(empty)[::-1]:
- self.trackers.pop(ind)
-
- ellipses = []
- pred_ids = []
- for i, pose in enumerate(poses):
- el = self.fitter.fit(pose)
- if el is not None:
- ellipses.append(el)
- if identities is not None:
- pred_ids.append(mode(identities[i])[0][0])
- if not len(trackers):
- matches = np.empty((0, 2), dtype=int)
- unmatched_detections = np.arange(len(ellipses))
- unmatched_trackers = np.empty((0, 6), dtype=int)
- else:
- ellipses_trackers = [Ellipse(*t[:5]) for t in trackers]
- cost_matrix = np.zeros((len(ellipses), len(ellipses_trackers)))
- for i, el in enumerate(ellipses):
- for j, el_track in enumerate(ellipses_trackers):
- cost = el.calc_similarity_with(el_track)
- if identities is not None:
- match = 2 if pred_ids[i] == self.trackers[j].id_ else 1
- cost *= match
- cost_matrix[i, j] = cost
- row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=True)
- unmatched_detections = [
- i for i, _ in enumerate(ellipses) if i not in row_indices
- ]
- unmatched_trackers = [
- j for j, _ in enumerate(trackers) if j not in col_indices
- ]
- matches = []
- for row, col in zip(row_indices, col_indices):
- val = cost_matrix[row, col]
- # diff = val - cost_matrix
- # diff[row, col] += val
- # if (
- # val < self.iou_threshold
- # or np.any(diff[row] <= 0.2)
- # or np.any(diff[:, col] <= 0.2)
- # ):
- if val < self.iou_threshold:
- unmatched_detections.append(row)
- unmatched_trackers.append(col)
- else:
- matches.append([row, col])
- if not len(matches):
- matches = np.empty((0, 2), dtype=int)
- else:
- matches = np.stack(matches)
- unmatched_trackers = np.asarray(unmatched_trackers)
- unmatched_detections = np.asarray(unmatched_detections)
-
- animalindex = []
- for t, tracker in enumerate(self.trackers):
- if t not in unmatched_trackers:
- ind = matches[matches[:, 1] == t, 0][0]
- animalindex.append(ind)
- tracker.update(ellipses[ind].parameters)
- else:
- animalindex.append(-1)
-
- for i in unmatched_detections:
- trk = EllipseTracker(ellipses[i].parameters)
- if identities is not None:
- trk.id_ = mode(identities[i])[0][0]
- self.trackers.append(trk)
- animalindex.append(i)
-
- i = len(self.trackers)
- ret = []
- for trk in reversed(self.trackers):
- d = trk.state
- if (trk.time_since_update < 1) and (
- trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits
- ):
- ret.append(
- np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape(
- 1, -1
- )
- ) # for DLC we also return the original animalid
- # +1 as MOT benchmark requires positive >> this is removed for DLC!
- i -= 1
- # remove dead tracklet
- if trk.time_since_update > self.max_age:
- self.trackers.pop(i)
-
- if len(ret) > 0:
- return np.concatenate(ret)
- return np.empty((0, 7))
-
-
-class SORTSkeleton(SORTBase):
- def __init__(self, n_bodyparts, max_age=20, min_hits=3, oks_threshold=0.5):
- self.n_bodyparts = n_bodyparts
- self.max_age = max_age
- self.min_hits = min_hits
- self.oks_threshold = oks_threshold
- SkeletonTracker.n_trackers = 0
- super().__init__()
-
- @staticmethod
- def weighted_hausdorff(x, y):
- # Modified from scipy source code:
- # - to restrict its use to 2D
- # - to get rid of shuffling (since arrays are only (nbodyparts * 3) element long)
- # TODO - factor in keypoint confidence (and weight by # of observations??)
- cmax = 0
- for i in range(x.shape[0]):
- no_break_occurred = True
- cmin = np.inf
- for j in range(y.shape[0]):
- d = (x[i, 0] - y[j, 0]) ** 2 + (x[i, 1] - y[j, 1]) ** 2
- if d < cmax:
- no_break_occurred = False
- break
- if d < cmin:
- cmin = d
- if cmin != np.inf and cmin > cmax and no_break_occurred:
- cmax = cmin
- return np.sqrt(cmax)
-
- @staticmethod
- def object_keypoint_similarity(x, y):
- mask = ~np.isnan(x * y).all(axis=1) # Intersection visible keypoints
- xx = x[mask]
- yy = y[mask]
- dist = np.linalg.norm(xx - yy, axis=1)
- scale = np.sqrt(
- np.product(np.ptp(yy, axis=0))
- ) # square root of bounding box area
- oks = np.exp(-0.5 * (dist / (0.05 * scale)) ** 2)
- return np.mean(oks)
-
- def calc_pairwise_hausdorff_dist(self, poses, poses_ref):
- mat = np.zeros((len(poses), len(poses_ref)))
- for i, pose in enumerate(poses):
- for j, pose_ref in enumerate(poses_ref):
- mat[i, j] = self.weighted_hausdorff(pose, pose_ref)
- return mat
-
- def calc_pairwise_oks(self, poses, poses_ref):
- mat = np.zeros((len(poses), len(poses_ref)))
- for i, pose in enumerate(poses):
- for j, pose_ref in enumerate(poses_ref):
- mat[i, j] = self.object_keypoint_similarity(pose, pose_ref)
- return mat
-
- def track(self, poses):
- self.n_frames += 1
-
- if not len(self.trackers):
- for pose in poses:
- tracker = SkeletonTracker(self.n_bodyparts)
- tracker.state = pose
- self.trackers.append(tracker)
-
- poses_ref = []
- for i, tracker in enumerate(self.trackers):
- pose_ref = tracker.predict()
- poses_ref.append(pose_ref.reshape((-1, 2)))
-
- # mat = self.calc_pairwise_oks(poses, poses_ref)
- mat = self.calc_pairwise_hausdorff_dist(poses, poses_ref)
- row_indices, col_indices = linear_sum_assignment(mat, maximize=False)
-
- unmatched_poses = [p for p, _ in enumerate(poses) if p not in row_indices]
- unmatched_trackers = [
- t for t, _ in enumerate(poses_ref) if t not in col_indices
- ]
- # Remove matched detections with low OKS
- # matches = []
- # for row, col in zip(row_indices, col_indices):
- # if mat[row, col] < self.oks_threshold:
- # unmatched_poses.append(row)
- # unmatched_trackers.append(col)
- # else:
- # matches.append([row, col])
- # if not len(matches):
- # matches = np.empty((0, 2), dtype=int)
- # else:
- # matches = np.stack(matches)
- matches = np.c_[row_indices, col_indices]
-
- animalindex = []
- for t, tracker in enumerate(self.trackers):
- if t not in unmatched_trackers:
- ind = matches[matches[:, 1] == t, 0][0]
- animalindex.append(ind)
- tracker.update(poses[ind])
- else:
- animalindex.append(-1)
-
- for i in unmatched_poses:
- tracker = SkeletonTracker(self.n_bodyparts)
- tracker.state = poses[i]
- self.trackers.append(tracker)
- animalindex.append(i)
-
- states = []
- i = len(self.trackers)
- for tracker in reversed(self.trackers):
- i -= 1
- if tracker.time_since_update > self.max_age:
- self.trackers.pop()
- continue
- state = tracker.predict()
- states.append(np.r_[state, [tracker.id, int(animalindex[i])]])
- if len(states) > 0:
- return np.stack(states)
- return np.empty((0, self.n_bodyparts * 2 + 2))
-
-
-class SORTBox(SORTBase):
- def __init__(self, max_age, min_hits, iou_threshold):
- self.max_age = max_age
- self.min_hits = min_hits
- self.iou_threshold = iou_threshold
- BoxTracker.n_trackers = 0
- super().__init__()
-
- def track(self, dets):
- self.n_frames += 1
-
- trackers = np.zeros((len(self.trackers), 5))
- for i in range(len(trackers)):
- trackers[i, :4] = self.trackers[i].predict()
- empty = np.isnan(trackers).any(axis=1)
- trackers = trackers[~empty]
- for ind in np.flatnonzero(empty)[::-1]:
- self.trackers.pop(ind)
-
- matched, unmatched_dets, unmatched_trks = self.match_detections_to_trackers(
- dets, trackers, self.iou_threshold
- )
-
- # update matched trackers with assigned detections
- animalindex = []
- for t, trk in enumerate(self.trackers):
- if t not in unmatched_trks:
- d = matched[np.where(matched[:, 1] == t)[0], 0]
- animalindex.append(d[0])
- trk.update(dets[d, :][0]) # update coordinates
- else:
- animalindex.append("nix") # lost trk!
-
- # create and initialise new trackers for unmatched detections
- for i in unmatched_dets:
- trk = BoxTracker(dets[i, :])
- self.trackers.append(trk)
- animalindex.append(i)
-
- i = len(self.trackers)
- ret = []
- for trk in reversed(self.trackers):
- d = trk.state
- if (trk.time_since_update < 1) and (
- trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits
- ):
- ret.append(
- np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape(
- 1, -1
- )
- ) # for DLC we also return the original animalid
- # +1 as MOT benchmark requires positive >> this is removed for DLC!
- i -= 1
- # remove dead tracklet
- if trk.time_since_update > self.max_age:
- self.trackers.pop(i)
-
- if len(ret) > 0:
- return np.concatenate(ret)
- return np.empty((0, 5))
-
- @staticmethod
- def match_detections_to_trackers(detections, trackers, iou_threshold):
- """
- Assigns detections to tracked object (both represented as bounding boxes)
-
- Returns 3 lists of matches, unmatched_detections and unmatched_trackers
- """
- if not len(trackers):
- return (
- np.empty((0, 2), dtype=int),
- np.arange(len(detections)),
- np.empty((0, 5), dtype=int),
- )
- iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32)
-
- for d, det in enumerate(detections):
- for t, trk in enumerate(trackers):
- iou_matrix[d, t] = calc_iou(det, trk)
- row_indices, col_indices = linear_sum_assignment(-iou_matrix)
-
- unmatched_detections = []
- for d, det in enumerate(detections):
- if d not in row_indices:
- unmatched_detections.append(d)
- unmatched_trackers = []
- for t, trk in enumerate(trackers):
- if t not in col_indices:
- unmatched_trackers.append(t)
-
- # filter out matched with low IOU
- matches = []
- for row, col in zip(row_indices, col_indices):
- if iou_matrix[row, col] < iou_threshold:
- unmatched_detections.append(row)
- unmatched_trackers.append(col)
- else:
- matches.append([row, col])
- if not len(matches):
- matches = np.empty((0, 2), dtype=int)
- else:
- matches = np.stack(matches)
- return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
-
-
-def fill_tracklets(tracklets, trackers, animals, imname):
- for content in trackers:
- tracklet_id, pred_id = content[-2:].astype(int)
- if tracklet_id not in tracklets:
- tracklets[tracklet_id] = {}
- if pred_id != -1:
- tracklets[tracklet_id][imname] = animals[pred_id]
- else: # Resort to the tracker prediction
- xy = np.asarray(content[:-2])
- pred = np.insert(xy, range(2, len(xy) + 1, 2), 1)
- tracklets[tracklet_id][imname] = pred
-
-
-def calc_bboxes_from_keypoints(data, slack=0, offset=0):
- data = np.asarray(data)
- if data.shape[-1] < 3:
- raise ValueError("Data should be of shape (n_animals, n_bodyparts, 3)")
-
- if data.ndim != 3:
- data = np.expand_dims(data, axis=0)
- bboxes = np.full((data.shape[0], 5), np.nan)
- bboxes[:, :2] = np.nanmin(data[..., :2], axis=1) - slack # X1, Y1
- bboxes[:, 2:4] = np.nanmax(data[..., :2], axis=1) + slack # X2, Y2
- bboxes[:, -1] = np.nanmean(data[..., 2]) # Average confidence
- bboxes[:, [0, 2]] += offset
- return bboxes
-
-
-def reconstruct_all_ellipses(data, sd):
- xy = data.droplevel("scorer", axis=1).drop("likelihood", axis=1, level=-1)
- if "single" in xy:
- xy.drop("single", axis=1, level="individuals", inplace=True)
- animals = xy.columns.get_level_values("individuals").unique()
- nrows = xy.shape[0]
- ellipses = np.full((len(animals), nrows, 5), np.nan)
- fitter = EllipseFitter(sd)
- for n, animal in enumerate(animals):
- data = xy.xs(animal, axis=1, level="individuals").values.reshape((nrows, -1, 2))
- for i, coords in enumerate(tqdm(data)):
- el = fitter.fit(coords.astype(np.float64))
- if el is not None:
- ellipses[n, i] = el.parameters
- return ellipses
-
-
-def _track_individuals(
- individuals, min_hits=1, max_age=5, similarity_threshold=0.6, track_method="ellipse"
-):
- if track_method not in TRACK_METHODS:
- raise ValueError(f"Unknown {track_method} tracker.")
-
- if track_method == "ellipse":
- tracker = SORTEllipse(max_age, min_hits, similarity_threshold)
- elif track_method == "box":
- tracker = SORTBox(max_age, min_hits, similarity_threshold)
- else:
- n_bodyparts = individuals[0][0].shape[0]
- tracker = SORTSkeleton(n_bodyparts, max_age, min_hits, similarity_threshold)
-
- tracklets = defaultdict(dict)
- all_hyps = dict()
- for i, (multi, single) in enumerate(tqdm(individuals)):
- if single is not None:
- tracklets["single"][i] = single
- if multi is None:
- continue
- if track_method == "box":
- # TODO: get cropping parameters and utilize!
- xy = calc_bboxes_from_keypoints(multi)
- else:
- xy = multi[..., :2]
- hyps = tracker.track(xy)
- all_hyps[i] = hyps
- for hyp in hyps:
- tracklet_id, pred_id = hyp[-2:].astype(int)
- if pred_id != -1:
- tracklets[tracklet_id][i] = multi[pred_id]
- return tracklets, all_hyps
+"""Backwards compatibility"""
+from deeplabcut.core.trackingutils import *
diff --git a/deeplabcut/modelzoo/api/__init__.py b/deeplabcut/pose_estimation_tensorflow/modelzoo/api/__init__.py
similarity index 100%
rename from deeplabcut/modelzoo/api/__init__.py
rename to deeplabcut/pose_estimation_tensorflow/modelzoo/api/__init__.py
diff --git a/deeplabcut/modelzoo/api/spatiotemporal_adapt.py b/deeplabcut/pose_estimation_tensorflow/modelzoo/api/spatiotemporal_adapt.py
similarity index 76%
rename from deeplabcut/modelzoo/api/spatiotemporal_adapt.py
rename to deeplabcut/pose_estimation_tensorflow/modelzoo/api/spatiotemporal_adapt.py
index bef6146ba2..76ed897313 100644
--- a/deeplabcut/modelzoo/api/spatiotemporal_adapt.py
+++ b/deeplabcut/pose_estimation_tensorflow/modelzoo/api/spatiotemporal_adapt.py
@@ -8,13 +8,21 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-import deeplabcut
import glob
import os
-from deeplabcut.modelzoo.utils import parse_available_supermodels
-from deeplabcut.modelzoo.api import superanimal_inference
-from deeplabcut.utils.plotting import _plot_trajectories
+import io
from pathlib import Path
+import yaml
+from deeplabcut.pose_estimation_tensorflow.modelzoo.api.superanimal_inference import (
+ video_inference,
+)
+from deeplabcut.utils.auxiliaryfunctions import (
+ get_deeplabcut_path,
+ load_analyzed_data,
+ read_config,
+)
+from deeplabcut.utils.make_labeled_video import create_labeled_video
+from deeplabcut.utils.plotting import _plot_trajectories
class SpatiotemporalAdaptation:
@@ -72,12 +80,6 @@ def __init__(
if scale_list is None:
scale_list = []
- supermodels = parse_available_supermodels()
- if supermodel_name not in supermodels:
- raise ValueError(
- f"`supermodel_name` should be one of: {', '.join(supermodels)}."
- )
-
self.video_path = video_path
self.supermodel_name = supermodel_name
self.scale_list = scale_list
@@ -88,32 +90,53 @@ def __init__(
self.modelfolder = modelfolder
self.init_weights = init_weights
+ project_name = "_".join(supermodel_name.split("_")[:-1])
+ model_name = supermodel_name.split("_")[-1]
+ self.project_name = project_name
+ self.model_name = model_name
+
if not customized_pose_config:
- dlc_root_path = os.sep.join(deeplabcut.__file__.split(os.sep)[:-1])
- self.customized_pose_config = os.path.join(
- dlc_root_path,
- "pose_estimation_tensorflow",
- "superanimal_configs",
- supermodels[self.supermodel_name],
+ dlc_root_path = get_deeplabcut_path()
+
+ project_config = read_config(
+ os.path.join(
+ dlc_root_path, "modelzoo", "project_configs", f"{project_name}.yaml"
+ )
+ )
+
+ model_config = read_config(
+ os.path.join(
+ dlc_root_path, "modelzoo", "model_configs", f"{model_name}.yaml"
+ )
)
+
+ joints = [i for i in range(len(project_config["bodyparts"]))]
+ num_joints = len(joints)
+ model_config["all_joints"] = joints
+ model_config["all_joints_names"] = project_config["bodyparts"]
+ model_config["num_joints"] = num_joints
+ model_config["num_limbs"] = int((num_joints * (num_joints - 1)) // 2)
+ self.customized_pose_config = {**project_config, **model_config}
else:
self.customized_pose_config = customized_pose_config
def before_adapt_inference(self, make_video=False, **kwargs):
if self.init_weights != "":
print("using customized weights", self.init_weights)
- _, datafiles = superanimal_inference.video_inference(
+ _, datafiles = video_inference(
[self.video_path],
- self.supermodel_name,
+ self.project_name,
+ self.model_name,
videotype=self.videotype,
scale_list=self.scale_list,
init_weights=self.init_weights,
customized_test_config=self.customized_pose_config,
)
else:
- self.init_weights, datafiles = superanimal_inference.video_inference(
+ self.init_weights, datafiles = video_inference(
[self.video_path],
- self.supermodel_name,
+ self.project_name,
+ self.model_name,
videotype=self.videotype,
scale_list=self.scale_list,
customized_test_config=self.customized_pose_config,
@@ -125,14 +148,14 @@ def before_adapt_inference(self, make_video=False, **kwargs):
_plot_trajectories(datafiles[0])
if make_video:
- deeplabcut.create_labeled_video(
+ create_labeled_video(
"",
[self.video_path],
videotype=self.videotype,
filtered=False,
init_weights=self.init_weights,
draw_skeleton=True,
- superanimal_name=self.supermodel_name,
+ superanimal_name=self.project_name,
**kwargs,
)
@@ -167,7 +190,7 @@ def adaptation_training(self, displayiters=500, saveiters=1000, **kwargs):
vname = str(Path(self.video_path).stem)
video_root = Path(self.video_path).parent
- _, pseudo_label_path, _, _ = deeplabcut.auxiliaryfunctions.load_analyzed_data(
+ _, pseudo_label_path, _, _ = load_analyzed_data(
video_root, vname, DLCscorer, False, ""
)
if self.modelfolder != "":
@@ -175,19 +198,13 @@ def adaptation_training(self, displayiters=500, saveiters=1000, **kwargs):
self.adapt_iterations = kwargs.get("adapt_iterations", self.adapt_iterations)
- if os.path.exists(
- os.path.join(self.modelfolder, f"snapshot-{self.adapt_iterations}.index")
- ):
- print(
- f"model checkpoint snapshot-{self.adapt_iterations}.index exists, skipping the video adaptation"
- )
- else:
- self.train_without_project(
- pseudo_label_path,
- displayiters=displayiters,
- saveiters=saveiters,
- **kwargs,
- )
+
+ self.train_without_project(
+ pseudo_label_path,
+ displayiters=displayiters,
+ saveiters=saveiters,
+ **kwargs,
+ )
def after_adapt_inference(self, **kwargs):
pattern = os.path.join(
@@ -208,9 +225,10 @@ def after_adapt_inference(self, **kwargs):
# spatial pyramid can still be useful for reducing jittering and quantization error
- _, datafiles = superanimal_inference.video_inference(
+ _, datafiles = video_inference(
[self.video_path],
- self.supermodel_name,
+ self.project_name,
+ self.model_name,
videotype=self.videotype,
init_weights=adapt_weights,
scale_list=scale_list,
@@ -220,13 +238,13 @@ def after_adapt_inference(self, **kwargs):
if kwargs.pop("plot_trajectories", True):
_plot_trajectories(datafiles[0])
- deeplabcut.create_labeled_video(
+ create_labeled_video(
ref_proj_config_path,
[self.video_path],
videotype=self.videotype,
filtered=False,
init_weights=adapt_weights,
draw_skeleton=True,
- superanimal_name=self.supermodel_name,
+ superanimal_name=self.project_name,
**kwargs,
)
diff --git a/deeplabcut/modelzoo/api/superanimal_inference.py b/deeplabcut/pose_estimation_tensorflow/modelzoo/api/superanimal_inference.py
similarity index 72%
rename from deeplabcut/modelzoo/api/superanimal_inference.py
rename to deeplabcut/pose_estimation_tensorflow/modelzoo/api/superanimal_inference.py
index cf28974d7e..a22f93c828 100644
--- a/deeplabcut/modelzoo/api/superanimal_inference.py
+++ b/deeplabcut/pose_estimation_tensorflow/modelzoo/api/superanimal_inference.py
@@ -8,10 +8,12 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
+import glob
import os
import os.path
import pickle
import time
+import warnings
from pathlib import Path
import imgaug.augmenters as iaa
@@ -20,18 +22,11 @@
from skimage.util import img_as_ubyte
from tqdm import tqdm
-from deeplabcut.modelzoo.utils import parse_available_supermodels
from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.pose_estimation_tensorflow.core import predict as single_predict
from deeplabcut.pose_estimation_tensorflow.core import predict_multianimal as predict
from deeplabcut.utils import auxiliaryfunctions
from deeplabcut.utils.auxfun_videos import VideoWriter
-from dlclibrary.dlcmodelzoo.modelzoo_download import (
- download_huggingface_model,
- MODELOPTIONS,
-)
-import glob
-import warnings
warnings.simplefilter("ignore", category=RuntimeWarning)
@@ -260,7 +255,8 @@ def _video_inference(
def video_inference(
videos,
- superanimal_name,
+ project_name,
+ model_name,
scale_list=[],
videotype="avi",
destfolder=None,
@@ -270,48 +266,54 @@ def video_inference(
init_weights="",
customized_test_config="",
):
- if superanimal_name not in MODELOPTIONS:
- raise ValueError(
- f"{superanimal_name} not available. Available ones are: {MODELOPTIONS}. If you are confident `superanimal_name` is right, try updating `dlclibrary` with `pip install -U dlclibrary`."
- )
-
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()
if customized_test_config == "":
- supermodels = parse_available_supermodels()
- test_cfg = load_config(
+ project_cfg = load_config(
os.path.join(
dlc_root_path,
- "pose_estimation_tensorflow",
- "superanimal_configs",
- supermodels[superanimal_name],
+ "modelzoo",
+ "project_configs",
+ f"{project_name}.yaml",
)
)
+ model_cfg = load_config(
+ os.path.join(
+ dlc_root_path,
+ "modelzoo",
+ "model_configs",
+ f"{model_name}.yaml",
+ )
+ )
+ test_cfg = {**project_cfg, **model_cfg}
+ test_cfg["all_joints"] = [i for i in range(len(test_cfg["bobyparts"]))]
+ test_cfg["all_joints_names"] = test_cfg["bobyparts"]
+ num_joints = len(test_cfg["all_joints"])
+ test_cfg["num_joints"] = num_joints
+ test_cfg["num_limbs"] = int((num_joints * (num_joints - 1)) // 2)
+
else:
- test_cfg = load_config(customized_test_config)
+ test_cfg = customized_test_config
# add a temp folder for checkpoint
weight_folder = str(
Path(dlc_root_path)
- / "pose_estimation_tensorflow"
- / "models"
- / "pretrained"
- / (superanimal_name + "_weights")
+ / "modelzoo"
+ / "checkpoints"
+ / f"{project_name}_{model_name}"
)
- pat = os.path.join(weight_folder, "snapshot-*.index")
- snapshots = glob.glob(pat)
- if not len(snapshots):
- download_huggingface_model(superanimal_name, weight_folder)
- snapshots = glob.glob(pat)
- else:
- print(f"{weight_folder} exists, using the downloaded weights")
-
+ snapshots = glob.glob(os.path.join(weight_folder, "snapshot-*.index"))
test_cfg["partaffinityfield_graph"] = []
test_cfg["partaffinityfield_predict"] = False
if init_weights != "":
test_cfg["init_weights"] = init_weights
else:
+ if len(snapshots) == 0:
+ raise FileNotFoundError(
+ f"Did not find any super animal snapshots in {weight_folder}"
+ )
+
init_weights = os.path.abspath(snapshots[0]).replace(".index", "")
test_cfg["init_weights"] = init_weights
@@ -437,3 +439,102 @@ def video_inference(
df.to_hdf(dataname, key="df_with_missing")
return init_weights, datafiles
+
+
+def _video_inference_superanimal(
+ videos,
+ project_name,
+ model_name,
+ scale_list=[],
+ videotype=".mp4",
+ video_adapt=False,
+ plot_trajectories=True,
+ pcutoff=0.1,
+ adapt_iterations=1000,
+ pseudo_threshold=0.1,
+):
+ """
+ WARNING: This function is an internal utility function and should not be
+ called directly. It is designed to be used by deeplabcut.modelzoo.api.video_inference.py
+
+ Makes prediction based on a super animal model. Note right now we only support single animal video inference
+
+ The index of the trained network is specified by parameters in the config file (in particular the variable 'snapshotindex')
+
+ Output: The labels are stored as MultiIndex Pandas Array, which contains the name of the network, body part name, (x, y) label position \n
+ in pixels, and the likelihood for each frame per body part. These arrays are stored in an efficient Hierarchical Data Format (HDF) \n
+ in the same directory, where the video is stored.
+
+ Parameters
+ ----------
+ videos: list
+ A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.
+
+ superanimal_name: str
+ The name of the superanimal model. We currently only support "superanimal_quadruped" and "superanimal_topviewmouse"
+ scale_list: list
+ A list of int containing the target height of the multi scale test time augmentation. By default it uses the original size. Users are advised to try a wide range of scale list when the super model does not give reasonable results
+
+ videotype: string, optional
+ Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed. The default is ``.avi``
+
+ video_adapt: bool, optional
+ Set True if you want to apply video adaptation to make the resulted video less jittering and better. However, adaptation training takes more time than usual video inference
+
+ plot_trajectories: bool, optional (default=True)
+ By default, plot the trajectories of various body parts across the video.
+
+ pcutoff: float, optional
+ Keypoints confidence that are under pcutoff will not be shown in the resulted video
+
+ adapt_iterations: int, optional:
+ Number of iterations for adaptation training
+
+ pseudo_threshold: float, default 0.1
+ Video adaptation only uses predictions that are above pseudo_threshold
+
+ Given a list of scales for spatial pyramid, i.e. [600, 700]
+
+ scale_list = range(600,800,100)
+
+ superanimal_name = 'superanimal_topviewmouse'
+ videotype = 'mp4'
+ scale_list = [200, 300, 400]
+ deeplabcut.video_inference_superanimal(
+ video,
+ superanimal_name,
+ videotype = '.avi',
+ scale_list = scale_list,
+ )
+ >>>
+ """
+ from deeplabcut.pose_estimation_tensorflow.modelzoo.api import (
+ SpatiotemporalAdaptation,
+ )
+
+ superanimal_name = project_name + "_" + model_name
+ for video in videos:
+ modelfolder = Path(video).parent / f"{Path(video).stem}_video_adaptation"
+ modelfolder.mkdir(exist_ok=True, parents=True)
+
+ adapter = SpatiotemporalAdaptation(
+ video,
+ superanimal_name,
+ modelfolder=str(modelfolder),
+ videotype=video.split(".")[-1],
+ scale_list=scale_list,
+ )
+ if not video_adapt:
+ adapter.before_adapt_inference(
+ make_video=True, pcutoff=pcutoff, plot_trajectories=plot_trajectories
+ )
+ else:
+ adapter.before_adapt_inference(make_video=False)
+ adapter.adaptation_training(
+ adapt_iterations=adapt_iterations,
+ pseudo_threshold=pseudo_threshold,
+ )
+ adapter.after_adapt_inference(
+ pcutoff=pcutoff,
+ plot_trajectories=plot_trajectories,
+ )
diff --git a/deeplabcut/pose_estimation_tensorflow/predict_supermodel.py b/deeplabcut/pose_estimation_tensorflow/predict_supermodel.py
deleted file mode 100644
index 1933e013b3..0000000000
--- a/deeplabcut/pose_estimation_tensorflow/predict_supermodel.py
+++ /dev/null
@@ -1,110 +0,0 @@
-#
-# DeepLabCut Toolbox (deeplabcut.org)
-# © A. & M.W. Mathis Labs
-# https://github.com/DeepLabCut/DeepLabCut
-#
-# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
-#
-# Licensed under GNU Lesser General Public License v3.0
-#
-from pathlib import Path
-from deeplabcut.modelzoo.api import SpatiotemporalAdaptation
-
-
-def video_inference_superanimal(
- videos,
- superanimal_name,
- scale_list=[],
- videotype=".mp4",
- video_adapt=False,
- plot_trajectories=True,
- pcutoff=0.1,
- adapt_iterations=1000,
- pseudo_threshold=0.1,
-):
- """
- Makes prediction based on a super animal model. Note right now we only support single animal video inference
-
- The index of the trained network is specified by parameters in the config file (in particular the variable 'snapshotindex')
-
- Output: The labels are stored as MultiIndex Pandas Array, which contains the name of the network, body part name, (x, y) label position \n
- in pixels, and the likelihood for each frame per body part. These arrays are stored in an efficient Hierarchical Data Format (HDF) \n
- in the same directory, where the video is stored.
-
- Parameters
- ----------
- videos: list
- A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.
-
- superanimal_name: str
- The name of the superanimal model. We currently only support "superanimal_quadruped" and "superanimal_topviewmouse"
- scale_list: list
- A list of int containing the target height of the multi scale test time augmentation. By default it uses the original size. Users are advised to try a wide range of scale list when the super model does not give reasonable results
-
- videotype: string, optional
- Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed. The default is ``.avi``
-
- video_adapt: bool, optional
- Set True if you want to apply video adaptation to make the resulted video less jittering and better. However, adaptation training takes more time than usual video inference
-
- plot_trajectories: bool, optional (default=True)
- By default, plot the trajectories of various body parts across the video.
-
- pcutoff: float, optional
- Keypoints confidence that are under pcutoff will not be shown in the resulted video
-
- adapt_iterations: int, optional:
- Number of iterations for adaptation training
-
- pseudo_threshold: float, default 0.1
- Video adaptation only uses predictions that are above pseudo_threshold
-
- Given a list of scales for spatial pyramid, i.e. [600, 700]
-
- scale_list = range(600,800,100)
-
- superanimal_name = 'superanimal_topviewmouse'
- videotype = 'mp4'
- scale_list = [200, 300, 400]
- deeplabcut.video_inference_superanimal(
- video,
- superanimal_name,
- videotype = '.avi',
- scale_list = scale_list,
- )
- >>>
- """
- from deeplabcut.utils.auxiliaryfunctions import get_deeplabcut_path
-
- for video in videos:
- vname = Path(video).stem
- dlcparent_path = get_deeplabcut_path()
- modelfolder = (
- Path(dlcparent_path)
- / "pose_estimation_tensorflow"
- / "models"
- / "pretrained"
- / (superanimal_name + "_" + vname + "_weights")
- )
- adapter = SpatiotemporalAdaptation(
- video,
- superanimal_name,
- modelfolder=modelfolder,
- videotype=videotype,
- scale_list=scale_list,
- )
- if not video_adapt:
- adapter.before_adapt_inference(
- make_video=True, pcutoff=pcutoff, plot_trajectories=plot_trajectories
- )
- else:
- adapter.before_adapt_inference(make_video=False)
- adapter.adaptation_training(
- adapt_iterations=adapt_iterations,
- pseudo_threshold=pseudo_threshold,
- )
- adapter.after_adapt_inference(
- pcutoff=pcutoff,
- plot_trajectories=plot_trajectories,
- )
diff --git a/deeplabcut/pose_estimation_tensorflow/predict_videos.py b/deeplabcut/pose_estimation_tensorflow/predict_videos.py
index f40f633657..412b25e367 100644
--- a/deeplabcut/pose_estimation_tensorflow/predict_videos.py
+++ b/deeplabcut/pose_estimation_tensorflow/predict_videos.py
@@ -31,9 +31,9 @@
from skimage.util import img_as_ubyte
from tqdm import tqdm
+from deeplabcut.core import trackingutils, inferenceutils
from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.pose_estimation_tensorflow.core import predict
-from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils, trackingutils
from deeplabcut.refine_training_dataset.stitch import stitch_tracklets
from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal, auxfun_models
diff --git a/deeplabcut/pose_estimation_tensorflow/superanimal_configs/superquadruped.yaml b/deeplabcut/pose_estimation_tensorflow/superanimal_configs/superquadruped.yaml
deleted file mode 100644
index b6088f9c9b..0000000000
--- a/deeplabcut/pose_estimation_tensorflow/superanimal_configs/superquadruped.yaml
+++ /dev/null
@@ -1,150 +0,0 @@
-all_joints:
-- - 0
-- - 1
-- - 2
-- - 3
-- - 4
-- - 5
-- - 6
-- - 7
-- - 8
-- - 9
-- - 10
-- - 11
-- - 12
-- - 13
-- - 14
-- - 15
-- - 16
-- - 17
-- - 18
-- - 19
-- - 20
-- - 21
-- - 22
-- - 23
-- - 24
-- - 25
-- - 26
-- - 27
-- - 28
-- - 29
-- - 30
-- - 31
-- - 32
-- - 33
-- - 34
-- - 35
-- - 36
-- - 37
-- - 38
-all_joints_names:
-- nose
-- upper_jaw
-- lower_jaw
-- mouth_end_right
-- mouth_end_left
-- right_eye
-- right_earbase
-- right_earend
-- right_antler_base
-- right_antler_end
-- left_eye
-- left_earbase
-- left_earend
-- left_antler_base
-- left_antler_end
-- neck_base
-- neck_end
-- throat_base
-- throat_end
-- back_base
-- back_end
-- back_middle
-- tail_base
-- tail_end
-- front_left_thai
-- front_left_knee
-- front_left_paw
-- front_right_thai
-- front_right_knee
-- front_right_paw
-- back_left_paw
-- back_left_thai
-- back_right_thai
-- back_left_knee
-- back_right_knee
-- back_right_paw
-- belly_bottom
-- body_middle_right
-- body_middle_left
-alpha_r: 0.02
-apply_prob: 0.5
-batch_size: 1
-clahe: true
-claheratio: 0.1
-crop_sampling: hybrid
-crop_size:
-- 400
-- 400
-cropratio: 0.4
-dataset: training-datasets/iteration-0/UnaugmentedDataSet_ma_superquadrupedMarch30/ma_superquadruped_maDLC_scorer85shuffle1.pickle
-dataset_type: multi-animal-imgaug
-decay_steps: 30000
-display_iters: 500
-edge: false
-emboss:
- alpha:
- - 0.0
- - 1.0
- embossratio: 0.1
- strength:
- - 0.5
- - 1.5
-global_scale: 0.8
-histeq: true
-histeqratio: 0.1
-init_weights:
-intermediate_supervision: false
-intermediate_supervision_layer: 12
-location_refinement: true
-locref_huber_loss: true
-locref_loss_weight: 0.05
-locref_stdev: 7.2801
-lr_init: 0.0005
-max_input_size: 1500
-max_shift: 0.4
-metadataset: training-datasets/iteration-0/UnaugmentedDataSet_ma_superquadrupedMarch30/Documentation_data-ma_superquadruped_85shuffle1.pickle
-min_input_size: 64
-mirror: false
-multi_stage: true
-multi_step:
-- - 0.0001
- - 7500
-- - 5.0e-05
- - 12000
-- - 1.0e-05
- - 1000000
-net_type: resnet_50
-num_idchannel: 0
-num_joints: 39
-num_limbs: 741
-optimizer: adam
-pafwidth: 20
-pairwise_huber_loss: false
-pairwise_loss_weight: 0.1
-pairwise_predict: false
-partaffinityfield_graph: []
-partaffinityfield_predict: false
-pos_dist_thresh: 17
-pre_resize: []
-project_path:
-rotation: 25
-rotratio: 0.4
-save_iters: 10000
-scale_jitter_lo: 0.5
-scale_jitter_up: 1.25
-sharpen: false
-sharpenratio: 0.3
-weigh_only_present_joints: false
-gradient_masking: true
diff --git a/deeplabcut/pose_estimation_tensorflow/visualizemaps.py b/deeplabcut/pose_estimation_tensorflow/visualizemaps.py
index ad6ce14779..111a695277 100644
--- a/deeplabcut/pose_estimation_tensorflow/visualizemaps.py
+++ b/deeplabcut/pose_estimation_tensorflow/visualizemaps.py
@@ -8,12 +8,15 @@
#
# Licensed under GNU Lesser General Public License v3.0
#
-
-
import os
import matplotlib.pyplot as plt
-import numpy as np
from skimage.transform import resize
+from deeplabcut.core.visualization import (
+ form_figure, # for backwards compatibility
+ visualize_scoremaps,
+ visualize_locrefs,
+ visualize_paf,
+)
def extract_maps(
@@ -268,79 +271,6 @@ def resize_all_maps(image, scmap, locref, paf):
return scmap, (locref_x, locref_y), paf
-def form_figure(nx, ny):
- fig, ax = plt.subplots(frameon=False)
- ax.set_xlim(0, nx)
- ax.set_ylim(0, ny)
- ax.axis("off")
- ax.invert_yaxis()
- fig.tight_layout()
- return fig, ax
-
-
-def visualize_scoremaps(image, scmap):
- ny, nx = np.shape(image)[:2]
- fig, ax = form_figure(nx, ny)
- ax.imshow(image)
- ax.imshow(scmap, alpha=0.5)
- return fig, ax
-
-
-def visualize_locrefs(image, scmap, locref_x, locref_y, step=5, zoom_width=0):
- fig, ax = visualize_scoremaps(image, scmap)
- X, Y = np.meshgrid(np.arange(locref_x.shape[1]), np.arange(locref_x.shape[0]))
- M = np.zeros(locref_x.shape, dtype=bool)
- M[scmap < 0.5] = True
- U = np.ma.masked_array(locref_x, mask=M)
- V = np.ma.masked_array(locref_y, mask=M)
- ax.quiver(
- X[::step, ::step],
- Y[::step, ::step],
- U[::step, ::step],
- V[::step, ::step],
- color="r",
- units="x",
- scale_units="xy",
- scale=1,
- angles="xy",
- )
- if zoom_width > 0:
- maxloc = np.unravel_index(np.argmax(scmap), scmap.shape)
- ax.set_xlim(maxloc[1] - zoom_width, maxloc[1] + zoom_width)
- ax.set_ylim(maxloc[0] + zoom_width, maxloc[0] - zoom_width)
- return fig, ax
-
-
-def visualize_paf(image, paf, step=5, colors=None):
- ny, nx = np.shape(image)[:2]
- fig, ax = form_figure(nx, ny)
- ax.imshow(image)
- n_fields = paf.shape[2]
- if colors is None:
- colors = ["r"] * n_fields
- for n in range(n_fields):
- U = paf[:, :, n, 0]
- V = paf[:, :, n, 1]
- X, Y = np.meshgrid(np.arange(U.shape[1]), np.arange(U.shape[0]))
- M = np.zeros(U.shape, dtype=bool)
- M[U**2 + V**2 < 0.5 * 0.5**2] = True
- U = np.ma.masked_array(U, mask=M)
- V = np.ma.masked_array(V, mask=M)
- ax.quiver(
- X[::step, ::step],
- Y[::step, ::step],
- U[::step, ::step],
- V[::step, ::step],
- scale=50,
- headaxislength=4,
- alpha=1,
- width=0.002,
- color=colors[n],
- angles="xy",
- )
- return fig, ax
-
-
def _save_individual_subplots(fig, axes, labels, output_path):
for ax, label in zip(axes, labels):
extent = ax.get_tightbbox(fig.canvas.renderer).transformed(
diff --git a/deeplabcut/pose_tracking_pytorch/apis.py b/deeplabcut/pose_tracking_pytorch/apis.py
index 0e3671e16c..07a9fc13d8 100644
--- a/deeplabcut/pose_tracking_pytorch/apis.py
+++ b/deeplabcut/pose_tracking_pytorch/apis.py
@@ -11,30 +11,30 @@
def transformer_reID(
- config,
- videos,
- videotype="",
- shuffle=1,
- trainingsetindex=0,
- track_method="ellipse",
- n_tracks=None,
- n_triplets=1000,
- train_epochs=100,
- train_frac=0.8,
- modelprefix="",
- destfolder=None,
+ config: str,
+ videos: list[str],
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ track_method: str = "ellipse",
+ n_tracks: int | None = None,
+ n_triplets: int = 1000,
+ train_epochs: int = 100,
+ train_frac: float = 0.8,
+ modelprefix: str = "",
+ destfolder: str = None,
):
"""
Enables tracking with transformer.
Substeps include:
+ - Mines triplets from tracklets in videos (from another tracker)
+ - These triplets are later used to tran a transformer with triplet loss
+ - The transformer derived appearance similarity is then used as a stitching loss
+ when tracklets are stitched during tracking.
- - Mines triplets from tracklets in videos (from another tracker)
- - These triplets are later used to tran a transformer with triplet loss
- - The transformer derived appearance similarity is then used as a stitching loss when tracklets are
- stitched during tracking.
-
- Outputs: The tracklet file is saved in the same folder where the non-transformer tracklet file is stored.
+ Outputs: The tracklet file is saved in the same folder where the non-transformer
+ tracklet file is stored.
Parameters
----------
@@ -42,11 +42,14 @@ def transformer_reID(
Full path of the config.yaml file as a string.
videos: list
- A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored.
+ A list of strings containing the full paths to videos for analysis or a path to
+ the directory, where all the videos with same extension are stored.
videotype: string, optional
- Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed.
- If left unspecified, videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept.
+ Checks for the extension of the video in case the input to the video is a
+ directory. Only videos with this extension are analyzed.
+ If left unspecified, videos with common extensions ('avi', 'mp4', 'mov', 'mpeg',
+ 'mkv') are kept.
shuffle : int, optional
which shuffle to use
@@ -74,8 +77,15 @@ def transformer_reID(
--------
Training model for one video based on ellipse-tracker derived tracklets
- >>> deeplabcut.transformer_reID(path_config_file,[''/home/alex/video.mp4'],track_method="ellipse")
-
+ >>> config = "/home/users/.../dlc-project-2025-01-01/config.yaml"
+ >>> videos = ['/home/alex/video.mp4']
+ >>> deeplabcut.transformer_reID(config, videos, shuffle=1, track_method="ellipse")
+ >>> deeplabcut.create_labeled_video(
+ >>> config,
+ >>> videos,
+ >>> shuffle=1,
+ >>> track_method="transformer",
+ >>> )
--------
"""
@@ -94,7 +104,7 @@ def transformer_reID(
modelprefix=modelprefix,
)
- deeplabcut.pose_estimation_tensorflow.create_tracking_dataset(
+ deeplabcut.compat.create_tracking_dataset(
config,
videos,
track_method,
diff --git a/deeplabcut/pose_tracking_pytorch/create_dataset.py b/deeplabcut/pose_tracking_pytorch/create_dataset.py
index 9660d85256..b42f44ab93 100644
--- a/deeplabcut/pose_tracking_pytorch/create_dataset.py
+++ b/deeplabcut/pose_tracking_pytorch/create_dataset.py
@@ -13,7 +13,7 @@
import os
import pickle
import shelve
-from deeplabcut.pose_estimation_tensorflow.lib import trackingutils
+from deeplabcut.core import trackingutils
from deeplabcut.refine_training_dataset.stitch import TrackletStitcher
from pathlib import Path
from .tracking_utils.preprocessing import query_feature_by_coord_in_img_space
@@ -33,8 +33,7 @@ def save_train_triplets(feature_fname, triplets, out_name):
feature_dict = shelve.open(feature_fname, protocol=pickle.DEFAULT_PROTOCOL)
- nframes = len(feature_dict.keys())
-
+ nframes = max(len(feature_dict.keys()), 2)
zfill_width = int(np.ceil(np.log10(nframes)))
for triplet in triplets:
diff --git a/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py b/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py
index 6f1597157e..a0b2df5a31 100644
--- a/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py
+++ b/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py
@@ -58,6 +58,6 @@ def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
diff = coordinates - ref_coord
diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
- match_id = np.argmin(np.nanmean(diff, axis=(1, 2)))
-
+ masked_means = np.ma.masked_invalid(np.nanmean(diff, axis=(1, 2)))
+ match_id = np.argmin(masked_means)
return features[match_id]
diff --git a/deeplabcut/refine_training_dataset/outlier_frames.py b/deeplabcut/refine_training_dataset/outlier_frames.py
index 7d0124461b..2ec2465afa 100644
--- a/deeplabcut/refine_training_dataset/outlier_frames.py
+++ b/deeplabcut/refine_training_dataset/outlier_frames.py
@@ -23,7 +23,7 @@
import statsmodels.api as sm
from skimage.util import img_as_ubyte
-from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils
+from deeplabcut.core import inferenceutils
from deeplabcut.utils import (
auxiliaryfunctions,
auxfun_multianimal,
@@ -235,7 +235,7 @@ def extract_outlier_frames(
outlieralgorithm: str, optional, default="jump".
String specifying the algorithm used to detect the outliers.
- * ``'Fitting'`` fits a Auto Regressive Integrated Moving Average model to the
+ * ``'fitting'`` fits an Auto Regressive Integrated Moving Average model to the
data and computes the distance to the estimated data. Larger distances than
epsilon are then potentially identified as outliers
* ``'jump'`` identifies larger jumps than 'epsilon' in any body part
diff --git a/deeplabcut/refine_training_dataset/stitch.py b/deeplabcut/refine_training_dataset/stitch.py
index cec322ef02..274802370e 100644
--- a/deeplabcut/refine_training_dataset/stitch.py
+++ b/deeplabcut/refine_training_dataset/stitch.py
@@ -25,7 +25,7 @@
import deeplabcut
from deeplabcut.utils.auxfun_videos import VideoWriter
from functools import partial
-from deeplabcut.pose_estimation_tensorflow.lib.trackingutils import (
+from deeplabcut.core.trackingutils import (
calc_iou,
TRACK_METHODS,
)
@@ -125,7 +125,7 @@ def centroid(self):
return self._centroid
def _update_centroid(self):
- like = self.data[..., 2:3]
+ like = self.data[..., 2:3] + 1e-10 # Avoid division by zero in very uncertain tracklets
self._centroid = np.nansum(self.xy * like, axis=1) / np.nansum(like, axis=1)
@property
@@ -1159,7 +1159,7 @@ def stitch_tracklets(
if n_tracks is None:
n_tracks = len(animal_names)
- DLCscorer, _ = deeplabcut.utils.auxiliaryfunctions.GetScorerName(
+ DLCscorer, _ = deeplabcut.utils.auxiliaryfunctions.get_scorer_name(
cfg,
shuffle,
cfg["TrainingFraction"][trainingsetindex],
diff --git a/deeplabcut/utils/auxfun_models.py b/deeplabcut/utils/auxfun_models.py
index 0dba614550..b8c46e1f9c 100644
--- a/deeplabcut/utils/auxfun_models.py
+++ b/deeplabcut/utils/auxfun_models.py
@@ -19,7 +19,6 @@
"""
import os
-import tensorflow as tf
from pathlib import Path
from deeplabcut.utils import auxiliaryfunctions
@@ -46,7 +45,6 @@
def check_for_weights(modeltype, parent_path):
"""gets local path to network weights and checks if they are present. If not, downloads them from tensorflow.org"""
-
if modeltype not in MODELTYPE_FILEPATH_MAP.keys():
print(
"Currently ResNet (50, 101, 152), MobilenetV2 (1, 0.75, 0.5 and 0.35) and EfficientNet (b0-b6) are supported, please change 'resnet' entry in config.yaml!"
@@ -159,6 +157,8 @@ def tarfilenamecutting(tarf):
def set_visible_devices(gputouse: int):
+ import tensorflow as tf
+
physical_devices = tf.config.list_physical_devices("GPU")
n_devices = len(physical_devices)
if gputouse >= n_devices:
diff --git a/deeplabcut/utils/auxfun_multianimal.py b/deeplabcut/utils/auxfun_multianimal.py
index fa337444fa..5f0314f887 100644
--- a/deeplabcut/utils/auxfun_multianimal.py
+++ b/deeplabcut/utils/auxfun_multianimal.py
@@ -33,7 +33,7 @@
from deeplabcut.utils import auxiliaryfunctions, conversioncode
from deeplabcut.generate_training_dataset import trainingsetmanipulation
-from deeplabcut.pose_estimation_tensorflow.lib.trackingutils import TRACK_METHODS
+from deeplabcut.core.trackingutils import TRACK_METHODS
def reorder_individuals_in_df(df: pd.DataFrame, order: list) -> pd.DataFrame:
@@ -147,7 +147,7 @@ def prune_paf_graph(list_of_edges, desired_n_edges=None, average_degree=None):
)
while True:
- g = nx.Graph(random.sample(G.edges, desired_n_edges))
+ g = nx.Graph(random.sample(list(G.edges), desired_n_edges))
if len(g.nodes) == n_nodes and nx.is_connected(g):
print("Valid subgraph found...")
break
diff --git a/deeplabcut/utils/auxfun_videos.py b/deeplabcut/utils/auxfun_videos.py
index c6fd1bd47b..6a937d1263 100644
--- a/deeplabcut/utils/auxfun_videos.py
+++ b/deeplabcut/utils/auxfun_videos.py
@@ -159,6 +159,24 @@ def get_bbox(self, relative=False):
y2 = int(self._height * y2)
return x1, x2, y1, y2
+ def set_bbox(self, x1, x2, y1, y2, relative=False):
+ if x2 <= x1 or y2 <= y1:
+ raise ValueError(
+ f"Coordinates look wrong... " f"Ensure {x1} < {x2} and {y1} < {y2}."
+ )
+ if not relative:
+ x1 /= self._width
+ x2 /= self._width
+ y1 /= self._height
+ y2 /= self._height
+ bbox = x1, x2, y1, y2
+ if any(coord > 1 for coord in bbox):
+ warnings.warn(
+ "Bounding box larger than the video... " "Clipping to video dimensions."
+ )
+ bbox = tuple(map(lambda x: min(x, 1), bbox))
+ self._bbox = bbox
+
@property
def fps(self):
return self._fps
@@ -205,24 +223,6 @@ def __init__(self, video_path, codec="h264", dpi=100, fps=None):
if fps:
self.fps = fps
- def set_bbox(self, x1, x2, y1, y2, relative=False):
- if x2 <= x1 or y2 <= y1:
- raise ValueError(
- f"Coordinates look wrong... " f"Ensure {x1} < {x2} and {y1} < {y2}."
- )
- if not relative:
- x1 /= self._width
- x2 /= self._width
- y1 /= self._height
- y2 /= self._height
- bbox = x1, x2, y1, y2
- if any(coord > 1 for coord in bbox):
- warnings.warn(
- "Bounding box larger than the video... " "Clipping to video dimensions."
- )
- bbox = tuple(map(lambda x: min(x, 1), bbox))
- self._bbox = bbox
-
def shorten(
self, start, end, suffix="short", dest_folder=None, validate_inputs=True
):
@@ -322,6 +322,21 @@ def crop(self, suffix="crop", dest_folder=None):
subprocess.call(command, shell=True)
return output_path
+ def rotate(self, angle, rotatecw="Arbitrary", suffix="rotated", dest_folder=None):
+ output_path = self.make_output_path(suffix, dest_folder)
+ command = f'ffmpeg -n -i "{self.video_path}" -vf '
+ if rotatecw == "Arbitrary":
+ angle = np.deg2rad(angle)
+ command += f'rotate={angle} '
+ elif rotatecw == "Yes":
+ command += 'transpose=1 '
+ else:
+ raise ValueError("Unknown rotation direction.")
+
+ command += f'-c:a copy "{output_path}"'
+ subprocess.call(command, shell=True)
+ return output_path
+
def rescale(
self,
width,
@@ -560,6 +575,47 @@ def DownSampleVideo(
return writer.rescale(width, height, rotatecw, angle, outsuffix, outpath)
+def rotate_video(vname, angle, rotatecw="Arbitrary", outsuffix="rotated", outpath=None):
+ """
+ Auxiliary function to rotate a video and output it to the same folder with "outsuffix" appended in its name.
+ Angle is in degrees.
+
+ Returns the full path to the rotated video!
+
+ Parameter
+ ----------
+ vname : string
+ A string containing the full path of the video.
+
+ angle: float
+ Angle to rotate by in degrees. Negative values rotate counter-clockwise.
+
+ rotatecw: str
+ Default "Arbitrary", rotates clockwise if "Yes", "Arbitrary" for arbitrary rotation by specified angle.
+
+ outsuffix: str
+ Suffix for output videoname (see example).
+
+ outpath: str
+ Output path for saving video to (by default will be the same folder as the video)
+
+ Examples
+ ----------
+
+ Linux/MacOs
+ >>> deeplabcut.rotate_video('/data/videos/mouse1.avi',angle=90)
+
+ Rotates the video by 90 degrees and saves it in /data/videos as mouse1rotated.avi
+
+ Windows:
+ >>> shortenedvideoname=deeplabcut.rotate_video('C:\\yourusername\\rig-95\\Videos\\reachingvideo1.avi', angle=180,rotatecw='Yes')
+
+ Rotates the video by 180 degrees and saves it in C:\\yourusername\\rig-95\\Videos as reachingvideo1rotated.avi
+ """
+ writer = VideoWriter(vname)
+ return writer.rotate(angle, rotatecw, outsuffix, outpath)
+
+
def draw_bbox(video):
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector, Button
diff --git a/deeplabcut/utils/auxiliaryfunctions.py b/deeplabcut/utils/auxiliaryfunctions.py
index 7226bc06c6..817211b53d 100644
--- a/deeplabcut/utils/auxiliaryfunctions.py
+++ b/deeplabcut/utils/auxiliaryfunctions.py
@@ -17,20 +17,24 @@
https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""
+from __future__ import annotations
import os
import typing
-from typing import List
import pickle
import warnings
from pathlib import Path
+from typing import List
+
import numpy as np
import pandas as pd
import ruamel.yaml.representer
import yaml
from ruamel.yaml import YAML
-from deeplabcut.pose_estimation_tensorflow.lib.trackingutils import TRACK_METHODS
-from deeplabcut.utils import auxfun_videos
+
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.trackingutils import TRACK_METHODS
+from deeplabcut.utils import auxfun_videos, auxfun_multianimal
def create_config_template(multianimal=False):
@@ -39,105 +43,120 @@ def create_config_template(multianimal=False):
"""
if multianimal:
yaml_str = """\
- # Project definitions (do not edit)
- Task:
- scorer:
- date:
- multianimalproject:
- identity:
- \n
- # Project path (change when moving around)
- project_path:
- \n
- # Annotation data set configuration (and individual video cropping parameters)
- video_sets:
- individuals:
- uniquebodyparts:
- multianimalbodyparts:
- bodyparts:
- \n
- # Fraction of video to start/stop when extracting frames for labeling/refinement
- start:
- stop:
- numframes2pick:
- \n
- # Plotting configuration
- skeleton:
- skeleton_color:
- pcutoff:
- dotsize:
- alphavalue:
- colormap:
- \n
- # Training,Evaluation and Analysis configuration
- TrainingFraction:
- iteration:
- default_net_type:
- default_augmenter:
- default_track_method:
- snapshotindex:
- batch_size:
- \n
- # Cropping Parameters (for analysis and outlier frame detection)
- cropping:
- #if cropping is true for analysis, then set the values here:
- x1:
- x2:
- y1:
- y2:
- \n
- # Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
- corner2move2:
- move2corner:
+# Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+\n
+# Project path (change when moving around)
+project_path:
+\n
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+\n
+# Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+individuals:
+uniquebodyparts:
+multianimalbodyparts:
+bodyparts:
+\n
+# Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+\n
+# Plotting configuration
+skeleton:
+skeleton_color:
+pcutoff:
+dotsize:
+alphavalue:
+colormap:
+\n
+# Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+default_track_method:
+snapshotindex:
+detector_snapshotindex:
+batch_size:
+\n
+# Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+#if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+\n
+# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+\n
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
"""
else:
yaml_str = """\
- # Project definitions (do not edit)
- Task:
- scorer:
- date:
- multianimalproject:
- identity:
- \n
- # Project path (change when moving around)
- project_path:
- \n
- # Annotation data set configuration (and individual video cropping parameters)
- video_sets:
- bodyparts:
- \n
- # Fraction of video to start/stop when extracting frames for labeling/refinement
- start:
- stop:
- numframes2pick:
- \n
- # Plotting configuration
- skeleton:
- skeleton_color:
- pcutoff:
- dotsize:
- alphavalue:
- colormap:
- \n
- # Training,Evaluation and Analysis configuration
- TrainingFraction:
- iteration:
- default_net_type:
- default_augmenter:
- snapshotindex:
- batch_size:
- \n
- # Cropping Parameters (for analysis and outlier frame detection)
- cropping:
- #if cropping is true for analysis, then set the values here:
- x1:
- x2:
- y1:
- y2:
- \n
- # Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
- corner2move2:
- move2corner:
+# Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+\n
+# Project path (change when moving around)
+project_path:
+\n
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+\n
+# Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+\n
+# Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+\n
+# Plotting configuration
+skeleton:
+skeleton_color:
+pcutoff:
+dotsize:
+alphavalue:
+colormap:
+\n
+# Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+detector_snapshotindex:
+batch_size:
+detector_batch_size:
+\n
+# Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+#if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+\n
+# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+\n
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
"""
ruamelFile = YAML()
@@ -151,27 +170,27 @@ def create_config_template_3d():
"""
yaml_str = """\
# Project definitions (do not edit)
- Task:
- scorer:
- date:
- \n
+Task:
+scorer:
+date:
+\n
# Project path (change when moving around)
- project_path:
- \n
+project_path:
+\n
# Plotting configuration
- skeleton: # Note that the pairs must be defined, as you want them linked!
- skeleton_color:
- pcutoff:
- colormap:
- dotsize:
- alphaValue:
- markerType:
- markerColor:
- \n
+skeleton: # Note that the pairs must be defined, as you want them linked!
+skeleton_color:
+pcutoff:
+colormap:
+dotsize:
+alphaValue:
+markerType:
+markerColor:
+\n
# Number of cameras, camera names, path of the config files, shuffle index and trainingsetindex used to analyze videos:
- num_cameras:
- camera_names:
- scorername_3d: # Enter the scorer name for the 3D output
+num_cameras:
+camera_names:
+scorername_3d: # Enter the scorer name for the 3D output
"""
ruamelFile_3d = YAML()
cfg_file_3d = ruamelFile_3d.load(yaml_str)
@@ -188,7 +207,18 @@ def read_config(configname):
try:
with open(path, "r") as f:
cfg = ruamelFile.load(f)
- curr_dir = os.path.dirname(configname)
+ curr_dir = str(Path(configname).parent.resolve())
+
+ if cfg.get("engine") is None:
+ cfg["engine"] = Engine.TF.aliases[0]
+ write_config(configname, cfg)
+
+ if cfg.get("detector_snapshotindex") is None:
+ cfg["detector_snapshotindex"] = -1
+
+ if cfg.get("detector_batch_size") is None:
+ cfg["detector_batch_size"] = 1
+
if cfg["project_path"] != curr_dir:
cfg["project_path"] = curr_dir
write_config(configname, cfg)
@@ -206,7 +236,7 @@ def read_config(configname):
else:
raise FileNotFoundError(
- "Config file is not found. Please make sure that the file exists and/or that you passed the path of the config file correctly!"
+ f"Config file at {path} not found. Please make sure that the file exists and/or that you passed the path of the config file correctly!"
)
return cfg
@@ -271,6 +301,42 @@ def edit_config(configname, edits, output_name=""):
return cfg
+def get_bodyparts(cfg: dict) -> typing.List[str]:
+ """
+ Args:
+ cfg: a project configuration file
+
+ Returns: bodyparts listed in the project (does not include the unique_bodyparts entry)
+ """
+ if cfg.get("multianimalproject", False):
+ (
+ _,
+ _,
+ multianimal_bodyparts,
+ ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
+ return multianimal_bodyparts
+
+ return cfg["bodyparts"]
+
+
+def get_unique_bodyparts(cfg: dict) -> typing.List[str]:
+ """
+ Args:
+ cfg: a project configuration file
+
+ Returns: all unique bodyparts listed in the project
+ """
+ if cfg.get("multianimalproject", False):
+ (
+ _,
+ unique_bodyparts,
+ _,
+ ) = auxfun_multianimal.extractindividualsandbodyparts(cfg)
+ return unique_bodyparts
+
+ return []
+
+
def write_config_3d(configname, cfg):
"""
Write structured 3D config file.
@@ -380,7 +446,8 @@ def get_list_of_videos(
if isinstance(videotype, str):
videotype = [videotype]
-
+ if not videotype:
+ videotype = auxfun_videos.SUPPORTED_VIDEOS
# filter list of videos
videos = [
v
@@ -439,6 +506,53 @@ def grab_files_in_folder(folder, ext="", relative=True):
yield file if relative else os.path.join(folder, file)
+def filter_files_by_patterns(
+ folder: str | Path,
+ start_patterns: set[str] | None = None,
+ contain_patterns: set[str] | None = None,
+ end_patterns: set[str] | None = None,
+) -> List[Path]:
+ """
+ Filters files in a folder based on start, contain, and end patterns.
+
+ Args:
+ folder (str | Path): The folder to search for files.
+
+ start_patterns (Set[str] | None): Patterns the filenames should start with.
+ If None or empty, this pattern is not taken into account.
+
+ contain_patterns (set[str]): Patterns the filenames should contain.
+ If None or empty, this pattern is not taken into account.
+
+ end_patterns (set[str]): Patterns the filenames should end with.
+ If None or empty, this pattern is not taken into account.
+
+ Returns:
+ List[Path]: List of files that match the criteria.
+ """
+ folder = Path(folder) # Ensure the folder is a Path object
+ if not folder.is_dir():
+ raise ValueError(f"{folder} is not a valid directory.")
+
+ # Filter files based on the given patterns
+ matching_files = [
+ file
+ for file in folder.iterdir()
+ if file.is_file()
+ and (
+ not start_patterns
+ or any(file.name.startswith(start) for start in start_patterns)
+ )
+ and (
+ not contain_patterns
+ or any(contain in file.name for contain in contain_patterns)
+ )
+ and (not end_patterns or any(file.name.endswith(end) for end in end_patterns))
+ ]
+
+ return matching_files
+
+
def get_video_list(filename, videopath, videtype):
"""Get list of videos in a path (if filetype == all), otherwise just a specific file."""
videos = list(grab_files_in_folder(videopath, videtype))
@@ -454,7 +568,7 @@ def get_video_list(filename, videopath, videtype):
## Various functions to get filenames, foldernames etc. based on configuration parameters.
-def get_training_set_folder(cfg):
+def get_training_set_folder(cfg: dict) -> Path:
"""Training Set folder for config file based on parameters"""
Task = cfg["Task"]
date = cfg["date"]
@@ -486,16 +600,82 @@ def get_data_and_metadata_filenames(trainingsetfolder, trainFraction, shuffle, c
+ str(shuffle)
+ ".mat",
)
+
return datafn, metadatafn
-def get_model_folder(trainFraction, shuffle, cfg, modelprefix=""):
+def get_model_folder(
+ trainFraction: float,
+ shuffle: int,
+ cfg: dict,
+ modelprefix: str = "",
+ engine: Engine = Engine.TF,
+) -> Path:
+ """
+ Args:
+ trainFraction: the training fraction (as defined in the project configuration)
+ for which to get the model folder
+ shuffle: the index of the shuffle for which to get the model folder
+ cfg: the project configuration
+ modelprefix: The name of the folder
+ engine: The engine for which we want the model folder. Defaults to `tensorflow`
+ for backwards compatibility with DeepLabCut 2.X
+
+ Returns:
+ the relative path from the project root to the folder containing the model files
+ for a shuffle (configuration files, snapshots, training logs, ...)
+ """
+ proj_id = f"{cfg['Task']}{cfg['date']}"
+ return Path(
+ modelprefix,
+ engine.model_folder_name,
+ f"iteration-{cfg['iteration']}",
+ f"{proj_id}-trainset{int(trainFraction * 100)}shuffle{shuffle}",
+ )
+
+
+def get_evaluation_folder(
+ trainFraction: float,
+ shuffle: int,
+ cfg: dict,
+ engine: Engine | None = None,
+ modelprefix: str = "",
+) -> Path:
+ """
+ Args:
+ trainFraction: the training fraction (as defined in the project configuration)
+ for which to get the evaluation folder
+ shuffle: the index of the shuffle for which to get the evaluation folder
+ cfg: the project configuration
+ engine: The engine for which we want the model folder. Defaults to None,
+ which automatically gets the engine for the shuffle from the training
+ dataset metadata file.
+ modelprefix: The name of the folder
+
+ Returns:
+ the relative path from the project root to the folder containing the model files
+ for a shuffle (configuration files, snapshots, training logs, ...)
+ """
+ if engine is None:
+ from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine
+
+ engine = get_shuffle_engine(
+ cfg=cfg,
+ trainingsetindex=cfg["TrainingFraction"].index(trainFraction),
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
Task = cfg["Task"]
date = cfg["date"]
iterate = "iteration-" + str(cfg["iteration"])
+ if "eval_prefix" in cfg:
+ eval_prefix = cfg["eval_prefix"]
+ else:
+ eval_prefix = engine.results_folder_name
return Path(
modelprefix,
- "dlc-models",
+ eval_prefix,
iterate,
Task
+ date
@@ -529,27 +709,6 @@ def get_snapshots_from_folder(train_folder: Path) -> List[str]:
return sorted(snapshot_names, key=lambda name: int(name.split("-")[1]))
-def get_evaluation_folder(trainFraction, shuffle, cfg, modelprefix=""):
- Task = cfg["Task"]
- date = cfg["date"]
- iterate = "iteration-" + str(cfg["iteration"])
- if "eval_prefix" in cfg:
- eval_prefix = cfg["eval_prefix"]
- else:
- eval_prefix = "evaluation-results"
- return Path(
- modelprefix,
- eval_prefix,
- iterate,
- Task
- + date
- + "-trainset"
- + str(int(trainFraction * 100))
- + "shuffle"
- + str(shuffle),
- )
-
-
def get_deeplabcut_path():
"""Get path of where deeplabcut is currently running"""
import importlib.util
@@ -588,41 +747,69 @@ def form_data_containers(df, bodyparts):
def get_scorer_name(
- cfg, shuffle, trainFraction, trainingsiterations="unknown", modelprefix=""
+ cfg: dict,
+ shuffle: int,
+ trainFraction: float,
+ trainingsiterations: str | int = "unknown",
+ modelprefix: str = "",
+ engine: Engine | None = None,
):
"""Extract the scorer/network name for a particular shuffle, training fraction, etc.
+ If the engine is not specified, determines which to use from
Returns tuple of DLCscorer, DLCscorerlegacy (old naming convention)
"""
+ if engine is None:
+ from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine
+
+ engine = get_shuffle_engine(
+ cfg=cfg,
+ trainingsetindex=cfg["TrainingFraction"].index(trainFraction),
+ shuffle=shuffle,
+ modelprefix=modelprefix,
+ )
+
+ if engine == Engine.PYTORCH:
+ from deeplabcut.pose_estimation_pytorch.apis.utils import get_scorer_name
+
+ snapshot_index = None
+ if isinstance(trainingsiterations, int):
+ snapshot_index = trainingsiterations
+
+ dlc3_scorer = get_scorer_name(
+ cfg=cfg,
+ shuffle=shuffle,
+ train_fraction=trainFraction,
+ snapshot_index=snapshot_index,
+ detector_index=None,
+ modelprefix=modelprefix,
+ )
+ return dlc3_scorer, dlc3_scorer
Task = cfg["Task"]
date = cfg["date"]
if trainingsiterations == "unknown":
- snapshotindex = cfg["snapshotindex"]
- if cfg["snapshotindex"] == "all":
- print(
- "Changing snapshotindext to the last one -- plotting, videomaking, etc. should not be performed for all indices. For more selectivity enter the ordinal number of the snapshot you want (ie. 4 for the fifth) in the config file."
- )
- snapshotindex = -1
- else:
- snapshotindex = cfg["snapshotindex"]
-
- train_folder = (
- Path(cfg["project_path"])
- / get_model_folder(trainFraction, shuffle, cfg, modelprefix=modelprefix)
- / "train"
+ snapshotindex = get_snapshot_index_for_scorer(
+ "snapshotindex", cfg["snapshotindex"]
+ )
+ model_folder = get_model_folder(
+ trainFraction, shuffle, cfg, engine=engine, modelprefix=modelprefix
)
+ train_folder = Path(cfg["project_path"]) / model_folder / "train"
snapshot_names = get_snapshots_from_folder(train_folder)
-
snapshot_name = snapshot_names[snapshotindex]
trainingsiterations = (snapshot_name.split(os.sep)[-1]).split("-")[-1]
dlc_cfg = read_plainconfig(
os.path.join(
cfg["project_path"],
- str(get_model_folder(trainFraction, shuffle, cfg, modelprefix=modelprefix)),
+ str(
+ get_model_folder(
+ trainFraction, shuffle, cfg, engine=engine, modelprefix=modelprefix
+ )
+ ),
"train",
- "pose_cfg.yaml",
+ engine.pose_cfg_name,
)
)
# ABBREVIATE NETWORK NAMES -- esp. for mobilenet!
@@ -635,6 +822,8 @@ def get_scorer_name(
netname = "mobnet_" + str(int(float(dlc_cfg["net_type"].split("_")[-1]) * 100))
elif "efficientnet" in dlc_cfg["net_type"]:
netname = "effnet_" + dlc_cfg["net_type"].split("-")[1]
+ else:
+ raise ValueError(f"Failed to abbreviate network name: {dlc_cfg['net_type']}")
scorer = (
"DLC_"
@@ -729,30 +918,47 @@ def check_if_not_evaluated(folder, DLCscorer, DLCscorerlegacy, snapshot):
return True, dataname, DLCscorer
+def find_video_full_data(folder, videoname, scorer):
+ scorer_legacy = scorer.replace("DLC", "DeepCut")
+ full_files = filter_files_by_patterns(
+ folder=folder,
+ start_patterns={videoname + scorer, videoname + scorer_legacy},
+ contain_patterns={"full"},
+ end_patterns={"pickle"},
+ )
+ if not full_files:
+ raise FileNotFoundError(
+ f"No full data found in {folder} "
+ f"for video {videoname} and scorer {scorer}."
+ )
+ return full_files[0]
+
+
def find_video_metadata(folder, videoname, scorer):
"""For backward compatibility, let us search the substring 'meta'"""
scorer_legacy = scorer.replace("DLC", "DeepCut")
- meta = [
- file
- for file in grab_files_in_folder(folder, "pickle")
- if "meta" in file
- and (
- file.startswith(videoname + scorer)
- or file.startswith(videoname + scorer_legacy)
- )
- ]
- if not len(meta):
+ meta_files = filter_files_by_patterns(
+ folder=folder,
+ start_patterns={videoname + scorer, videoname + scorer_legacy},
+ contain_patterns={"meta"},
+ end_patterns={"pickle"},
+ )
+ if not meta_files:
raise FileNotFoundError(
f"No metadata found in {folder} "
f"for video {videoname} and scorer {scorer}."
)
- return os.path.join(folder, meta[0])
+ return meta_files[0]
def load_video_metadata(folder, videoname, scorer):
return read_pickle(find_video_metadata(folder, videoname, scorer))
+def load_video_full_data(folder, videoname, scorer):
+ return read_pickle(find_video_full_data(folder, videoname, scorer))
+
+
def find_analyzed_data(folder, videoname, scorer, filtered=False, track_method=""):
"""Find potential data files from the hints given to the function."""
scorer_legacy = scorer.replace("DLC", "DeepCut")
@@ -855,6 +1061,18 @@ def find_next_unlabeled_folder(config_path, verbose=False):
return next_folder
+def get_snapshot_index_for_scorer(name: str, index: int | str) -> int:
+ if index == "all":
+ print(
+ f"Changing {name} to the last one -- plotting, videomaking, etc. should "
+ "not be performed for all indices. For more selectivity enter the ordinal "
+ "number of the snapshot you want (ie. 4 for the fifth) in the config file."
+ )
+ return -1
+
+ return index
+
+
# aliases for backwards-compatibility.
SaveData = save_data
SaveMetadata = save_metadata
@@ -869,3 +1087,5 @@ def find_next_unlabeled_folder(config_path, verbose=False):
CheckifPostProcessing = check_if_post_processing
CheckifNotAnalyzed = check_if_not_analyzed
CheckifNotEvaluated = check_if_not_evaluated
+GetEvaluationFolder = get_evaluation_folder
+GetModelFolder = get_model_folder
diff --git a/deeplabcut/utils/auxiliaryfunctions_3d.py b/deeplabcut/utils/auxiliaryfunctions_3d.py
index 22d31b4555..2483623d31 100644
--- a/deeplabcut/utils/auxiliaryfunctions_3d.py
+++ b/deeplabcut/utils/auxiliaryfunctions_3d.py
@@ -322,12 +322,17 @@ def _associate_paired_view_tracks(tracklets1, tracklets2, F):
_t1 = np.c_[_t1, np.ones((*_t1.shape[:2], 1))]
_t2 = np.c_[_t2, np.ones((*_t2.shape[:2], 1))]
- # cost for any point in time of t1 being the same
- # any point in time of t2
- cost = np.abs(np.nansum(np.matmul(_t1, F) * _t2, axis=2))
+ try:
+ # cost for any point in time of t1 being the same
+ # any point in time of t2
+ cost = np.abs(np.nansum(np.matmul(_t1, F) * _t2, axis=2))
+
+ # Get average cost of the entire track
+ cost = cost.mean()
+ except:
+ # typically when dim 2 differs, with uniquebodyparts
+ cost = 100000.0
- # Get average cost of the entire track
- cost = cost.mean()
costs[i, j] = cost
match_inds = linear_sum_assignment(np.abs(costs))
diff --git a/deeplabcut/utils/frameselectiontools.py b/deeplabcut/utils/frameselectiontools.py
index dd2201e40f..947d32d59e 100644
--- a/deeplabcut/utils/frameselectiontools.py
+++ b/deeplabcut/utils/frameselectiontools.py
@@ -262,7 +262,9 @@ def KmeansbasedFrameselectioncv2(
if batchsize > nframes:
batchsize = nframes // 2
- allocated = False
+ ny_ = np.round(ny * ratio).astype(int)
+ nx_ = np.round(nx * ratio).astype(int)
+ DATA = np.empty((nframes, ny_, nx_ * 3 if color else nx_))
if len(Index) >= numframes2pick:
if (
np.mean(np.diff(Index)) > 1
@@ -282,13 +284,6 @@ def KmeansbasedFrameselectioncv2(
interpolation=cv2.INTER_NEAREST,
)
) # color trafo not necessary; lack thereof improves speed.
- if (
- not allocated
- ): #'DATA' not in locals(): #allocate memory in first pass
- DATA = np.empty(
- (nframes, np.shape(image)[0], np.shape(image)[1] * 3)
- )
- allocated = True
DATA[counter, :, :] = np.hstack(
[image[:, :, 0], image[:, :, 1], image[:, :, 2]]
)
@@ -306,13 +301,6 @@ def KmeansbasedFrameselectioncv2(
interpolation=cv2.INTER_NEAREST,
)
) # color trafo not necessary; lack thereof improves speed.
- if (
- not allocated
- ): #'DATA' not in locals(): #allocate memory in first pass
- DATA = np.empty(
- (nframes, np.shape(image)[0], np.shape(image)[1])
- )
- allocated = True
DATA[counter, :, :] = np.mean(image, 2)
else:
print("Extracting and downsampling...", nframes, " frames from the video.")
@@ -329,13 +317,6 @@ def KmeansbasedFrameselectioncv2(
interpolation=cv2.INTER_NEAREST,
)
) # color trafo not necessary; lack thereof improves speed.
- if (
- not allocated
- ): #'DATA' not in locals(): #allocate memory in first pass
- DATA = np.empty(
- (nframes, np.shape(image)[0], np.shape(image)[1] * 3)
- )
- allocated = True
DATA[counter, :, :] = np.hstack(
[image[:, :, 0], image[:, :, 1], image[:, :, 2]]
)
@@ -352,13 +333,6 @@ def KmeansbasedFrameselectioncv2(
interpolation=cv2.INTER_NEAREST,
)
) # color trafo not necessary; lack thereof improves speed.
- if (
- not allocated
- ): #'DATA' not in locals(): #allocate memory in first pass
- DATA = np.empty(
- (nframes, np.shape(image)[0], np.shape(image)[1])
- )
- allocated = True
DATA[counter, :, :] = np.mean(image, 2)
print("Kmeans clustering ... (this might take a while)")
diff --git a/deeplabcut/utils/make_labeled_video.py b/deeplabcut/utils/make_labeled_video.py
index 8d8c53d017..6a36cb1793 100644
--- a/deeplabcut/utils/make_labeled_video.py
+++ b/deeplabcut/utils/make_labeled_video.py
@@ -20,6 +20,7 @@
Hao Wu, hwu01@g.harvard.edu contributed the original OpenCV class. Thanks!
You can find the directory for your ffmpeg bindings by: "find / | grep ffmpeg" and then setting it.
"""
+from __future__ import annotations
import argparse
import os
@@ -28,27 +29,28 @@
# Dependencies
####################################################
import os.path
-from pathlib import Path
from functools import partial
-from multiprocessing import Pool, get_start_method
-from typing import Iterable, Callable, List, Optional, Union
+from multiprocessing import get_start_method, Pool
+from pathlib import Path
+from typing import Callable, Iterable, List, Optional, Union
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
+from matplotlib import patches
from matplotlib.animation import FFMpegWriter
from matplotlib.collections import LineCollection
-from skimage.draw import disk, line_aa, set_color
+from skimage.draw import disk, line_aa, set_color, rectangle_perimeter
from skimage.util import img_as_ubyte
from tqdm import trange
-from deeplabcut.modelzoo.utils import parse_available_supermodels
-from deeplabcut.pose_estimation_tensorflow.config import load_config
-from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal, visualization
+
+from deeplabcut.core.engine import Engine
+from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions, visualization
+from deeplabcut.utils.auxfun_videos import VideoWriter
from deeplabcut.utils.video_processor import (
VideoProcessorCV as vp,
) # used to CreateVideo
-from deeplabcut.utils.auxfun_videos import VideoWriter
def get_segment_indices(bodyparts2connect, all_bpts):
@@ -85,6 +87,10 @@ def CreateVideo(
displaycropped,
color_by,
confidence_to_alpha=None,
+ plot_bboxes=True,
+ bboxes_list=None,
+ bboxes_pcutoff=0.6,
+ bboxes_color: tuple | None = None,
):
"""Creating individual frames with labeled body parts and making a video"""
bpts = Dataframe.columns.get_level_values("bodyparts")
@@ -151,12 +157,37 @@ def CreateVideo(
C = colorclass.to_rgba(np.linspace(0, 1, nindividuals))
colors = (C[:, :3] * 255).astype(np.uint8)
+ if bboxes_color is None:
+ bboxes_color = (255, 0, 0)
+
with np.errstate(invalid="ignore"):
for index in trange(min(nframes, len(Dataframe))):
image = clip.load_frame()
if displaycropped:
image = image[y1:y2, x1:x2]
+ # Draw bounding boxes if required and present
+ if plot_bboxes and bboxes_list:
+ bboxes = bboxes_list[index]["bboxes"]
+ bbox_scores = bboxes_list[index]["bbox_scores"]
+ n_bboxes = bboxes.shape[0]
+ for i in range(n_bboxes):
+ bbox = bboxes[i, :]
+ x, y = bbox[0], bbox[1]
+ x += x1
+ y += y1
+ w, h = bbox[2], bbox[3]
+ confidence = bbox_scores[i]
+ if confidence < bboxes_pcutoff:
+ continue
+ rect_coords = rectangle_perimeter(start=(y, x), extent=(h, w))
+
+ set_color(
+ image,
+ rect_coords,
+ bboxes_color,
+ )
+
# Draw the skeleton for specific bodyparts to be connected as
# specified in the config file
if draw_skeleton:
@@ -224,10 +255,12 @@ def CreateVideoSlow(
draw_skeleton,
displaycropped,
color_by,
+ plot_bboxes=True,
+ bboxes_list=None,
+ bboxes_pcutoff=0.6,
+ bboxes_color: str | None = None,
):
"""Creating individual frames with labeled body parts and making a video"""
- # scorer=np.unique(Dataframe.columns.get_level_values(0))[0]
- # bodyparts2plot = list(np.unique(Dataframe.columns.get_level_values(1)))
if displaycropped:
ny, nx = y2 - y1, x2 - x1
@@ -284,6 +317,9 @@ def CreateVideoSlow(
else:
colors = visualization.get_cmap(nbodyparts, name=colormap)
+ if bboxes_color is None:
+ bboxes_color = "red"
+
nframes_digits = int(np.ceil(np.log10(nframes)))
if nframes_digits > 9:
raise Exception(
@@ -312,6 +348,29 @@ def CreateVideoSlow(
image = image[y1:y2, x1:x2]
ax.imshow(image)
+ # Draw bounding boxes of required and present
+ if plot_bboxes and bboxes_list:
+ bboxes = bboxes_list[index]["bboxes"]
+ bbox_scores = bboxes_list[index]["bbox_scores"]
+ n_bboxes = bboxes.shape[0]
+ for i in range(n_bboxes):
+ bbox = bboxes[i, :]
+ bbox_origin = (bbox[0], bbox[1])
+ (bbox_width, bbox_height) = (bbox[2], bbox[3])
+ bbox_confidence = bbox_scores[i]
+ if bbox_confidence < bboxes_pcutoff:
+ continue
+ rectangle = patches.Rectangle(
+ bbox_origin,
+ bbox_width,
+ bbox_height,
+ linewidth=1,
+ edgecolor=bboxes_color,
+ facecolor="none",
+ )
+ ax.add_patch(rectangle)
+
+ # Draw skeleton
if draw_skeleton:
for bpt1, bpt2 in bpts2connect:
if np.all(df_likelihood[[bpt1, bpt2], index] > pcutoff):
@@ -322,6 +381,7 @@ def CreateVideoSlow(
alpha=alphavalue,
)
+ # Draw bodyparts
for ind, num_bp, num_ind in bpts2color:
if df_likelihood[ind, index] > pcutoff:
if color_by == "bodypart":
@@ -360,37 +420,39 @@ def CreateVideoSlow(
def create_labeled_video(
- config,
- videos,
- videotype="",
- shuffle=1,
- trainingsetindex=0,
- filtered=False,
- fastmode=True,
- save_frames=False,
- keypoints_only=False,
- Frames2plot=None,
- displayedbodyparts="all",
- displayedindividuals="all",
- codec="mp4v",
- outputframerate=None,
- destfolder=None,
- draw_skeleton=False,
- trailpoints=0,
- displaycropped=False,
- color_by="bodypart",
- modelprefix="",
- init_weights="",
- track_method="",
- superanimal_name="",
- pcutoff=0.6,
- skeleton=[],
- skeleton_color="white",
- dotsize=8,
- colormap="rainbow",
- alphavalue=0.5,
- overwrite=False,
+ config: str,
+ videos: list[str],
+ videotype: str = "",
+ shuffle: int = 1,
+ trainingsetindex: int = 0,
+ filtered: bool = False,
+ fastmode: bool = True,
+ save_frames: bool = False,
+ keypoints_only: bool = False,
+ Frames2plot: list[int] | None = None,
+ displayedbodyparts: list[str] | str = "all",
+ displayedindividuals: list[str] | str = "all",
+ codec: str = "mp4v",
+ outputframerate: int | None = None,
+ destfolder: Path | str | None = None,
+ draw_skeleton: bool = False,
+ trailpoints: int = 0,
+ displaycropped: bool = False,
+ color_by: str = "bodypart",
+ modelprefix: str = "",
+ init_weights: str = "",
+ track_method: str = "",
+ superanimal_name: str = "",
+ pcutoff: float | None = None,
+ skeleton: list = [],
+ skeleton_color: str = "white",
+ dotsize: int = 8,
+ colormap: str = "rainbow",
+ alphavalue: float = 0.5,
+ overwrite: bool = False,
confidence_to_alpha: Union[bool, Callable[[float], float]] = False,
+ plot_bboxes: bool = True,
+ bboxes_pcutoff: float | None = None,
):
"""Labels the bodyparts in a video.
@@ -467,7 +529,7 @@ def create_labeled_video(
mode with saving frames.) If ``None``, which results in the original video
rate.
- destfolder: string or None, optional, default=None
+ destfolder: Path, string or None, optional, default=None
Specifies the destination folder that was used for storing analysis data. If
``None``, the path of the video file is used.
@@ -502,6 +564,25 @@ def create_labeled_video(
For multiple animals, must be either 'box', 'skeleton', or 'ellipse' and will
be taken from the config.yaml file if none is given.
+ superanimal_name: str, optional, default=""
+ Name of the superanimal model.
+
+ pcutoff: float, optional, default=None
+ Overrides the pcutoff set in the project configuration to plot the trajectories.
+
+ skeleton: list, optional, default=[],
+
+ skeleton_color: string, optional, default="white",
+ Color for the skeleton
+
+ dotsize, int, optional, default=8,
+ Size of label dots tu use
+
+ colormap: str, optional, default="rainbow",
+ Colormap to use for the labels
+
+ alphavalue: float, optional, default=0.5,
+
overwrite: bool, optional, default=False
If ``True`` overwrites existing labeled videos.
@@ -511,6 +592,12 @@ def create_labeled_video(
keypoint will be set as a function of its score: alpha = f(score). The default
function used when True is f(x) = max(0, (x - pcutoff)/(1 - pcutoff)).
+ plot_bboxes: bool, optional, default=True
+ If using Pytorch and in Top-Down mode, setting this to true will also plot the bounding boxes
+
+ bboxes_pcutoff, float, optional, default=None:
+ If plotting bounding boxes, this overrides the bboxes_pcutoff set in the model configuration.
+
Returns
-------
results : list[bool]
@@ -561,17 +648,61 @@ def create_labeled_video(
)
"""
if config == "":
- pass
+ if pcutoff is None:
+ pcutoff = 0.6
+ if bboxes_pcutoff is None:
+ bboxes_pcutoff = 0.6
+
+ individuals = [""]
+ uniquebodyparts = []
else:
cfg = auxiliaryfunctions.read_config(config)
- trainFraction = cfg["TrainingFraction"][trainingsetindex]
+ train_fraction = cfg["TrainingFraction"][trainingsetindex]
track_method = auxfun_multianimal.get_track_method(
cfg, track_method=track_method
)
+ if pcutoff is None:
+ pcutoff = cfg["pcutoff"]
+
+ # Get individuals from the config
+ individuals = cfg.get("individuals", [""])
+ uniquebodyparts = cfg.get("uniquebodyparts", [])
+
+ # Only for PyTorch engine - check if the shuffle was fine-tuned from a
+ # SuperAnimal model with memory replay -> SuperAnimal bodyparts must be used
+ model_folder = auxiliaryfunctions.get_model_folder(
+ train_fraction,
+ shuffle,
+ cfg,
+ modelprefix,
+ engine=Engine.PYTORCH,
+ )
+ model_config_path = (
+ Path(config).parent / model_folder / "train" / Engine.PYTORCH.pose_cfg_name
+ )
+ if model_config_path.exists():
+ model_config = auxiliaryfunctions.read_plainconfig(str(model_config_path))
+ if (
+ model_config["train_settings"]
+ .get("weight_init", {})
+ .get("memory_replay", False)
+ ):
+ superanimal_name = model_config["train_settings"]["weight_init"][
+ "dataset"
+ ]
+ if bboxes_pcutoff is None:
+ bboxes_pcutoff = (
+ model_config.get("detector", {})
+ .get("model", {})
+ .get("box_score_thresh", 0.6)
+ )
+ else:
+ if bboxes_pcutoff is None:
+ bboxes_pcutoff = 0.6
if init_weights == "":
- DLCscorer, DLCscorerlegacy = auxiliaryfunctions.GetScorerName(
- cfg, shuffle, trainFraction, modelprefix=modelprefix
+ DLCscorer, DLCscorerlegacy = auxiliaryfunctions.get_scorer_name(
+ cfg, shuffle, train_fraction, modelprefix=modelprefix
) # automatically loads corresponding model (even training iteration based on snapshot index)
else:
DLCscorer = "DLC_" + Path(init_weights).stem
@@ -587,17 +718,16 @@ def create_labeled_video(
if superanimal_name != "":
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()
- supermodels = parse_available_supermodels()
- test_cfg = load_config(
+ test_cfg = auxiliaryfunctions.read_plainconfig(
os.path.join(
dlc_root_path,
- "pose_estimation_tensorflow",
- "superanimal_configs",
- supermodels[superanimal_name],
+ "modelzoo",
+ "project_configs",
+ f"{superanimal_name}.yaml",
)
)
- bodyparts = test_cfg["all_joints_names"]
+ bodyparts = test_cfg["bodyparts"]
cfg = {
"skeleton": skeleton,
"skeleton_color": skeleton_color,
@@ -605,6 +735,10 @@ def create_labeled_video(
"dotsize": dotsize,
"alphavalue": alphavalue,
"colormap": colormap,
+ "bodyparts": bodyparts,
+ "multianimalbodyparts": bodyparts,
+ "individuals": individuals,
+ "uniquebodyparts": uniquebodyparts,
}
else:
bodyparts = (
@@ -657,7 +791,10 @@ def create_labeled_video(
keypoints_only,
overwrite,
init_weights=init_weights,
+ pcutoff=pcutoff,
confidence_to_alpha=confidence_to_alpha,
+ plot_bboxes=plot_bboxes,
+ bboxes_pcutoff=bboxes_pcutoff,
)
if get_start_method() == "fork":
@@ -697,7 +834,10 @@ def proc_video(
overwrite,
video,
init_weights="",
+ pcutoff: float | None = None,
confidence_to_alpha: Optional[Callable[[float], float]] = None,
+ plot_bboxes: bool = True,
+ bboxes_pcutoff: float = 0.6,
):
"""Helper function for create_videos
@@ -710,10 +850,13 @@ def proc_video(
result : bool
``True`` if a video is successfully created.
"""
- videofolder = Path(video).parents[0]
+ videofolder = Path(video).parent
if destfolder is None:
destfolder = videofolder # where your folder with videos is.
+ if pcutoff is None:
+ pcutoff = cfg["pcutoff"]
+
auxiliaryfunctions.attempt_to_make_folder(destfolder)
os.chdir(destfolder) # THE VIDEO IS STILL IN THE VIDEO FOLDER
@@ -749,7 +892,10 @@ def proc_video(
s = "_id" if color_by == "individual" else "_bp"
else:
s = ""
- videooutname = filepath.replace(".h5", f"{s}_labeled.mp4")
+
+ videooutname = filepath.replace(
+ ".h5", f"{s}_p{int(100 * pcutoff)}_labeled.mp4"
+ )
if os.path.isfile(videooutname) and not overwrite:
print("Labeled video already created. Skipping...")
return
@@ -770,6 +916,24 @@ def proc_video(
if bp in bodyparts
]
+ # The full data file is not created for single-animal TensorFlow models
+ try:
+ full_data = auxiliaryfunctions.load_video_full_data(
+ destfolder, vname, DLCscorer
+ )
+ frames_dict = {
+ int(key.replace("frame", "")): value
+ for key, value in full_data.items()
+ if key.startswith("frame") and key[5:].isdigit()
+ }
+ bboxes_list = None
+ if "bboxes" in frames_dict.get(min(frames_dict.keys()), {}):
+ bboxes_list = [
+ frames_dict[key] for key in sorted(frames_dict.keys())
+ ]
+ except FileNotFoundError:
+ bboxes_list = None
+
if keypoints_only:
# Mask rather than drop unwanted bodyparts to ensure consistent coloring
mask = df.columns.get_level_values("bodyparts").isin(bodyparts)
@@ -783,7 +947,7 @@ def proc_video(
df,
videooutname,
inds,
- cfg["pcutoff"],
+ pcutoff,
cfg["dotsize"],
cfg["alphavalue"],
skeleton_color=skeleton_color,
@@ -805,7 +969,7 @@ def proc_video(
cfg["dotsize"],
cfg["colormap"],
cfg["alphavalue"],
- cfg["pcutoff"],
+ pcutoff,
trailpoints,
cropping,
x1,
@@ -821,10 +985,13 @@ def proc_video(
draw_skeleton,
displaycropped,
color_by,
+ plot_bboxes=plot_bboxes,
+ bboxes_list=bboxes_list,
+ bboxes_pcutoff=bboxes_pcutoff,
)
clip.close()
else:
- _create_labeled_video(
+ create_video(
video,
filepath,
keypoints2show=labeled_bpts,
@@ -832,7 +999,7 @@ def proc_video(
bbox=(x1, x2, y1, y2),
codec=codec,
output_path=videooutname,
- pcutoff=cfg["pcutoff"],
+ pcutoff=pcutoff,
dotsize=cfg["dotsize"],
cmap=cfg["colormap"],
color_by=color_by,
@@ -842,7 +1009,11 @@ def proc_video(
fps=outputframerate,
display_cropped=displaycropped,
confidence_to_alpha=confidence_to_alpha,
+ plot_bboxes=plot_bboxes,
+ bboxes_list=bboxes_list,
+ bboxes_pcutoff=bboxes_pcutoff,
)
+
return True
except FileNotFoundError as e:
@@ -850,15 +1021,15 @@ def proc_video(
return False
-def _create_labeled_video(
+def create_video(
video,
h5file,
keypoints2show="all",
animals2show="all",
skeleton_edges=None,
pcutoff=0.6,
- dotsize=8,
- cmap="cool",
+ dotsize=6,
+ cmap="rainbow",
color_by="bodypart",
skeleton_color="k",
trailpoints=0,
@@ -868,6 +1039,10 @@ def _create_labeled_video(
fps=None,
output_path="",
confidence_to_alpha=None,
+ plot_bboxes=True,
+ bboxes_list=None,
+ bboxes_pcutoff=0.6,
+ bboxes_color: tuple | None = None,
):
if color_by not in ("bodypart", "individual"):
raise ValueError("`color_by` should be either 'bodypart' or 'individual'.")
@@ -922,9 +1097,17 @@ def _create_labeled_video(
display_cropped,
color_by,
confidence_to_alpha=confidence_to_alpha,
+ plot_bboxes=plot_bboxes,
+ bboxes_list=bboxes_list,
+ bboxes_pcutoff=bboxes_pcutoff,
+ bboxes_color=bboxes_color,
)
+# for backwards compatibility
+_create_labeled_video = create_video
+
+
def create_video_with_keypoints_only(
df,
output_name,
@@ -1016,6 +1199,7 @@ def create_video_with_all_detections(
destfolder=None,
modelprefix="",
confidence_to_alpha: Union[bool, Callable[[float], float]] = False,
+ plot_bboxes: bool = True,
):
"""
Create a video labeled with all the detections stored in a '*_full.pickle' file.
@@ -1055,10 +1239,15 @@ def create_video_with_all_detections(
defined as a function f: [0, 1] -> [0, 1] such that the alpha value for a
keypoint will be set as a function of its score: alpha = f(score). The default
function used when True is f(x) = x.
+
+ plot_bboxes: bool, optional (default=True)
+ If detections were produced using a Pytorch Top-Down model, setting this parameter to True will also plot
+ the bounding boxes generated by the detector.
"""
- from deeplabcut.pose_estimation_tensorflow.lib.inferenceutils import Assembler
import re
+ from deeplabcut.core.inferenceutils import Assembler
+
cfg = auxiliaryfunctions.read_config(config)
trainFraction = cfg["TrainingFraction"][trainingsetindex]
DLCscorername, _ = auxiliaryfunctions.get_scorer_name(
@@ -1125,12 +1314,47 @@ def create_video_with_all_detections(
clip = vp(fname=video, sname=outputname, codec="mp4v")
ny, nx = clip.height(), clip.width()
+ bboxes_pcutoff = (
+ metadata.get("data", {})
+ .get("pytorch-config", {})
+ .get("detector", {})
+ .get("model", {})
+ .get("box_score_thresh", 0.6)
+ )
+ bboxes_color = (255, 0, 0)
+
for n in trange(clip.nframes):
frame = clip.load_frame()
if frame is None:
continue
try:
ind = frames.index(n)
+
+ # Draw bounding boxes of required and present
+ if plot_bboxes and "bboxes" in data[frame_names[ind]]:
+ bboxes = data[frame_names[ind]]["bboxes"]
+ bbox_scores = data[frame_names[ind]]["bbox_scores"]
+ n_bboxes = bboxes.shape[0]
+ for i in range(n_bboxes):
+ bbox = bboxes[i, :]
+ x, y = bbox[0], bbox[1]
+ x += x1
+ y += y1
+ w, h = bbox[2], bbox[3]
+ confidence = bbox_scores[i]
+ if confidence < bboxes_pcutoff:
+ continue
+ rect_coords = rectangle_perimeter(
+ start=(y, x), extent=(h, w)
+ )
+
+ set_color(
+ frame,
+ rect_coords,
+ bboxes_color,
+ )
+
+ # Draw detected bodyparts
dets = Assembler._flatten_detections(data[frame_names[ind]])
for det in dets:
if det.label not in bpts or det.confidence < pcutoff:
@@ -1164,6 +1388,7 @@ def create_video_with_all_detections(
def _create_video_from_tracks(video, tracks, destfolder, output_name, pcutoff, scale=1):
import subprocess
+
from tqdm import tqdm
if not os.path.isdir(destfolder):
diff --git a/deeplabcut/utils/multiprocessing.py b/deeplabcut/utils/multiprocessing.py
new file mode 100644
index 0000000000..3515b73125
--- /dev/null
+++ b/deeplabcut/utils/multiprocessing.py
@@ -0,0 +1,54 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""
+DeepLabCut2.2 Toolbox (deeplabcut.org)
+© A. & M. Mathis Labs
+https://github.com/DeepLabCut/DeepLabCut
+Please see AUTHORS for contributors.
+
+https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+Licensed under GNU Lesser General Public License v3.0
+"""
+import multiprocessing
+
+
+def _wrapper(func, queue, *args, **kwargs):
+ try:
+ result = func(*args, **kwargs)
+ queue.put(result) # Pass the result back via the queue
+ except Exception as e:
+ queue.put(e) # Pass any exception back via the queue
+
+
+def call_with_timeout(func, timeout, *args, **kwargs):
+ queue = multiprocessing.Queue()
+ process = multiprocessing.Process(
+ target=_wrapper, args=(func, queue, *args), kwargs=kwargs
+ )
+ process.start()
+ process.join(timeout)
+
+ if process.is_alive():
+ process.terminate() # Forcefully terminate the process
+ process.join()
+ raise TimeoutError(
+ f"Function {func.__name__} did not complete within {timeout} seconds."
+ )
+
+ if not queue.empty():
+ result = queue.get()
+ if isinstance(result, Exception):
+ raise result # Re-raise the exception if it occurred in the function
+ return result
+ else:
+ raise TimeoutError(
+ f"Function {func.__name__} completed but did not return a result."
+ )
diff --git a/deeplabcut/utils/plotting.py b/deeplabcut/utils/plotting.py
index 1320791a34..d059c14391 100644
--- a/deeplabcut/utils/plotting.py
+++ b/deeplabcut/utils/plotting.py
@@ -17,6 +17,7 @@
https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""
+from __future__ import annotations
import argparse
import os
@@ -32,7 +33,7 @@
import matplotlib.pyplot as plt
import numpy as np
-from deeplabcut.pose_estimation_tensorflow.lib import crossvalutils
+from deeplabcut.core import crossvalutils
from deeplabcut.utils import auxiliaryfunctions, auxfun_multianimal, visualization
@@ -187,6 +188,7 @@ def plot_trajectories(
resolution=100,
linewidth=1.0,
track_method="",
+ pcutoff: float | None = None,
):
"""Plots the trajectories of various bodyparts across the video.
@@ -251,6 +253,9 @@ def plot_trajectories(
For multiple animals, must be either 'box', 'skeleton', or 'ellipse' and will
be taken from the config.yaml file if none is given.
+ pcutoff: string, optional, default=None
+ Overrides the pcutoff set in the project configuration to plot the trajectories.
+
Returns
-------
None
@@ -266,6 +271,10 @@ def plot_trajectories(
)
"""
cfg = auxiliaryfunctions.read_config(config)
+
+ if pcutoff is None:
+ pcutoff = cfg["pcutoff"]
+
track_method = auxfun_multianimal.get_track_method(cfg, track_method=track_method)
trainFraction = cfg["TrainingFraction"][trainingsetindex]
@@ -308,7 +317,7 @@ def plot_trajectories(
linewidth,
cfg["colormap"],
cfg["alphavalue"],
- cfg["pcutoff"],
+ pcutoff,
suffix,
imagetype,
tmpfolder,
diff --git a/deeplabcut/utils/pseudo_label.py b/deeplabcut/utils/pseudo_label.py
new file mode 100644
index 0000000000..826fb75734
--- /dev/null
+++ b/deeplabcut/utils/pseudo_label.py
@@ -0,0 +1,506 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import glob
+import json
+import os
+from collections import defaultdict
+from pathlib import Path
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from scipy.spatial import distance
+from scipy.spatial.distance import cdist
+
+import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.modelzoo.generalized_data_converter.datasets import (
+ MaDLCDataFrame,
+ SingleDLCDataFrame,
+)
+from deeplabcut.pose_estimation_pytorch.apis.utils import get_inference_runners
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ select_device,
+ update_config,
+)
+
+
+class NumpyEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, np.ndarray):
+ return obj.tolist() # Convert ndarray to list
+ return super().default(obj)
+
+
+def xywh2xyxy(bbox):
+ temp_bbox = np.copy(bbox)
+ temp_bbox[2:] = temp_bbox[:2] + temp_bbox[2:]
+ return temp_bbox
+
+
+def optimal_match(gts_list, preds_list):
+ arranged_preds_list = []
+ num_gts = len(gts_list)
+ num_preds = len(preds_list)
+ cost_matrix = np.zeros((num_gts, num_preds))
+
+ for i in range(num_gts):
+ for j in range(num_preds):
+ cost_matrix[i, j] = distance.euclidean(
+ gts_list[i][..., :2].flatten(), preds_list[j][..., :2].flatten()
+ )
+ row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+ return col_ind
+
+
+def calculate_iou(box1, box2):
+ # Unpack the coordinates
+ x1_1, y1_1, x2_1, y2_1 = box1
+ x1_2, y1_2, x2_2, y2_2 = box2
+
+ # Calculate the coordinates of the intersection rectangle
+ inter_x1 = max(x1_1, x1_2)
+ inter_y1 = max(y1_1, y1_2)
+ inter_x2 = min(x2_1, x2_2)
+ inter_y2 = min(y2_1, y2_2)
+
+ # Calculate the width and height of the intersection rectangle
+ inter_width = max(0, inter_x2 - inter_x1)
+ inter_height = max(0, inter_y2 - inter_y1)
+
+ # Calculate the area of the intersection rectangle
+ inter_area = inter_width * inter_height
+
+ # Calculate the area of each bounding box
+ area_1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+ area_2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+
+ # Calculate the area of the union of the two bounding boxes
+ union_area = area_1 + area_2 - inter_area
+
+ # Calculate the IoU
+ iou = inter_area / union_area
+
+ return iou
+
+
+def video_to_frames(input_video, output_folder, cropping: list[int] | None = None):
+ # Create the output folder if it doesn't exist
+ video = cv2.VideoCapture(str(input_video))
+ # Get the frames per second (fps) of the video
+ fps = int(video.get(cv2.CAP_PROP_FPS))
+ # Initialize a frame counter
+ frame_count = 0
+ while True:
+ # Read a frame from the video
+ ret, frame = video.read()
+ # Break the loop if we have reached the end of the video
+ if not ret:
+ break
+ # Crop the frame if desired
+ if cropping is not None:
+ x1, x2, y1, y2 = cropping
+ frame = frame[y1:y2, x1:x2]
+
+ # Save the frame as an image file.
+ frame_str = str(frame_count).zfill(5)
+ frame_file = os.path.join(output_folder, "images", f"frame_{frame_str}.png")
+ cv2.imwrite(frame_file, frame)
+ # Increment the frame counter
+ frame_count += 1
+ # Release the video object and close the window (if open)
+ video.release()
+ # cv2.destroyAllWindows()
+
+
+def plot_cost_matrix(
+ matrix, gt_keypoint_names, pred_keypoint_names, conversion_plot_out_path
+):
+
+ matrix /= np.max(matrix)
+ fig, ax = plt.subplots()
+ heatmap = ax.pcolor(matrix, cmap=plt.cm.Blues, vmin=0, vmax=1)
+ ax.set_xticks(np.arange(matrix.shape[1]) + 0.5, minor=False)
+ ax.set_yticks(np.arange(matrix.shape[0]) + 0.5, minor=False)
+ ax.set_xlim(0, int(matrix.shape[1]))
+ ax.set_ylim(0, int(matrix.shape[0]))
+ ax.set_yticklabels(pred_keypoint_names, minor=False)
+ ax.set_xticklabels(gt_keypoint_names, minor=False)
+ ax.set_title("cost matrix")
+ plt.xticks(rotation=90)
+ fig = plt.gcf()
+ fig.tight_layout()
+
+ plt.savefig(conversion_plot_out_path, dpi=300)
+
+
+def keypoint_matching(
+ config_path: str | Path,
+ superanimal_name: str,
+ model_name: str,
+ detector_name: str,
+ copy_images: bool = False,
+ device: str | None = None,
+ train_file: str = "train.json",
+):
+ """Runs the keypoint matching algorithm for a DeepLabCut project
+
+ Matches project keypoints to SuperAnimal keypoints automatically, by running
+ SuperAnimal inference on all images in the dataset
+
+ Args:
+ config_path: The path of the DeepLabCut project configuration file.
+ superanimal_name: SuperAnimal dataset with which to run keypoint matching.
+ model_name: Name of the SuperAnimal pose model architecture with which to run
+ keypoint matching
+ detector_name: Name of the SuperAnimal detector architecture with which to run
+ keypoint matching
+ copy_images: When False, symlinks are created for the dataset used for keypoint
+ matching. Otherwise, images are copied from the `labeled-data` folder to the
+ folder used for keypoint matching.
+ device: The device on which to run keypoint matching.
+ train_file: The name of the file containing the labels to output.
+ """
+ config_path = Path(config_path)
+ cfg = af.read_config(str(config_path))
+ dlc_proj_root = config_path.parent
+
+ if "individuals" in cfg:
+ temp_dataset = MaDLCDataFrame(str(dlc_proj_root), "temp_dataset")
+ max_individuals = len(cfg["individuals"])
+ else:
+ temp_dataset = SingleDLCDataFrame(str(dlc_proj_root), "temp_dataset")
+ max_individuals = 1
+
+ memory_replay_folder = dlc_proj_root / "memory_replay"
+ temp_dataset.materialize(
+ str(memory_replay_folder), framework="coco", deepcopy=copy_images
+ )
+
+ # run inference on the train set
+ config = modelzoo.load_super_animal_config(
+ super_animal=superanimal_name,
+ model_name=model_name,
+ detector_name=detector_name,
+ )
+ if device is None:
+ device = select_device()
+
+ # get the SuperAnimal detector and pose model snapshot paths
+ pose_model_path = modelzoo.get_super_animal_snapshot_path(
+ dataset=superanimal_name, model_name=model_name,
+ )
+ detector_path = modelzoo.get_super_animal_snapshot_path(
+ dataset=superanimal_name, model_name=detector_name,
+ )
+
+ config = update_config(config, max_individuals, device)
+ individuals = [f"animal{i}" for i in range(max_individuals)]
+ config["metadata"]["individuals"] = individuals
+ train_file_path = os.path.join(memory_replay_folder, "annotations", train_file)
+
+ pose_runner, detector_runner = get_inference_runners(
+ config,
+ snapshot_path=pose_model_path,
+ max_individuals=max_individuals,
+ num_bodyparts=len(config["metadata"]["bodyparts"]),
+ num_unique_bodyparts=0,
+ detector_path=detector_path,
+ )
+
+ with open(train_file_path, "r") as f:
+ train_obj = json.load(f)
+
+ images = train_obj["images"]
+ annotations = train_obj["annotations"]
+ categories = train_obj["categories"]
+ image_name_to_id = {}
+ image_id_to_name = {}
+
+ image_name_to_gt = defaultdict(list)
+ image_name_to_bbox = defaultdict(list)
+ image_id_to_annotations = defaultdict(list)
+
+ for image in images:
+ # this only works with relative path as the testing image can be at a different folder
+ name = image["file_name"].split(os.sep)[-1]
+ image_name_to_id[name] = image["id"]
+ image_id_to_name[image["id"]] = name
+
+ for anno in annotations:
+ name = image_id_to_name[anno["image_id"]]
+ image_name_to_gt[name].append(anno)
+ image_name_to_bbox[name].append(anno["bbox"])
+
+ image_ids = set(image_name_to_id.values())
+ for anno in annotations:
+ image_id = anno["image_id"]
+ if anno["image_id"] in image_ids:
+ image_id_to_annotations[image_id].append(anno)
+
+ # need to support more image types
+ image_extensions = ["*.png", "*.jpg", "*.jpeg", "*.bmp", "*.gif", "*.tiff"]
+ images_in_folder = []
+ for ext in image_extensions:
+ images_in_folder.extend(
+ glob.glob(os.path.join(memory_replay_folder, "images", ext))
+ )
+
+ corresponded_images = []
+ for image in images_in_folder:
+ image_path = image
+ name = image.split(os.sep)[-1]
+ if name in image_name_to_id:
+ corresponded_images.append(image_path)
+
+ images = corresponded_images
+ bbox_gts = [
+ {"bboxes": np.array(image_name_to_bbox[image.split(os.sep)[-1]])}
+ for image in images
+ ]
+
+ pose_inputs = list(zip(images, bbox_gts))
+
+ # pose inference should return meta data for pseudo labeling
+ predictions = pose_runner.inference(pose_inputs)
+
+ with open(str(memory_replay_folder / "pseudo_predictions.json"), "w") as f:
+ json.dump(pose_inputs, f, cls=NumpyEncoder)
+
+ assert len(images) == len(predictions)
+
+ image_name_to_pred = {}
+ for image_path, prediction in zip(images, predictions):
+ name = image_path.split(os.sep)[-1]
+ image_name_to_pred[name] = prediction
+
+ pred_keypoint_names = config["metadata"]["bodyparts"]
+ num_pred_keypoints = len(pred_keypoint_names)
+ gt_keypoint_names = categories[0]["keypoints"]
+ num_gt_keypoints = len(gt_keypoint_names)
+
+ match_matrix = np.zeros((num_pred_keypoints, num_gt_keypoints))
+ match_dict = defaultdict(lambda: defaultdict(int))
+
+ for name, gts in image_name_to_gt.items():
+ bbox_gts = [np.array(gt["bbox"]) for gt in gts]
+ bbox_gts = [xywh2xyxy(e) for e in bbox_gts]
+ prediction = image_name_to_pred[name]
+ bbox_preds = [xywh2xyxy(pred) for pred in prediction["bboxes"]]
+ optimal_pred_indices = optimal_match(bbox_gts, bbox_preds)
+
+ for idx in range(len(bbox_gts)):
+ if idx == len(optimal_pred_indices):
+ break
+
+ optimal_index = optimal_pred_indices[idx]
+ matched_gt = np.array(gts[idx]["keypoints"])
+ matched_pred = prediction["bodyparts"][optimal_index]
+ matched_gt = matched_gt.reshape(num_gt_keypoints, -1)
+ matched_pred = matched_pred.reshape(num_pred_keypoints, -1)
+
+ pair_distance = cdist(matched_pred, matched_gt)
+ row_ind, column_ind = linear_sum_assignment(pair_distance)
+ for row, column in zip(row_ind, column_ind):
+ pred_kpt_name = pred_keypoint_names[row]
+ anno_kpt_name = gt_keypoint_names[column]
+ match_matrix[row][column] += 1
+ match_dict[pred_kpt_name][anno_kpt_name] += 1
+
+ row_ind, column_ind = linear_sum_assignment(match_matrix * -1)
+ keypoint_mapping_list = []
+
+ conversion_matrix_out_path = os.path.join(
+ memory_replay_folder, "confusion_matrix.png"
+ )
+
+ plot_cost_matrix(
+ match_matrix, gt_keypoint_names, pred_keypoint_names, conversion_matrix_out_path
+ )
+
+ for row, column in zip(row_ind, column_ind):
+ pred_kpt_name = pred_keypoint_names[row]
+ anno_kpt_name = gt_keypoint_names[column]
+ count = match_dict[pred_kpt_name][anno_kpt_name]
+ keypoint_mapping_list.append((pred_kpt_name, anno_kpt_name, count))
+
+ keypoint_mapping_list = sorted(
+ keypoint_mapping_list, key=lambda x: x[2], reverse=True
+ )
+
+ names = [e[:2] for e in keypoint_mapping_list]
+ conversion_table = {}
+ for pred, anno in names:
+ conversion_table[pred] = anno
+
+ conversion_table_out_path = os.path.join(
+ memory_replay_folder, "conversion_table.csv"
+ )
+ with open(conversion_table_out_path, "w") as f:
+ out = "gt, MasterName\n"
+ for name in pred_keypoint_names:
+ target = name
+ source = conversion_table.get(target, "")
+ out += f"{source}, {target}\n"
+ f.write(out)
+
+
+# this is to generate a coco project as an intermediate data
+def dlc3predictions_2_annotation_from_video(
+ predictions,
+ dest_proj_folder,
+ bodyparts,
+ superanimal_name,
+ pose_threshold=0.0,
+ bbox_threshold=0.0,
+):
+ """
+ For video adaptation, we also need to create a coco project
+ dlc3 predictions:
+
+ list of dictionary
+ [{
+ bodyparts:[] # (n_individuals, n_kpts, 3)
+ bboxes: [] # (n_individuals, 4) -> x,y,w,h
+ }]
+
+ coco result is a list of dictionary
+ # i might get a minimal version that works with my script
+
+ category_id:
+ image_id: []
+ image_path: []
+ keypoints: []
+ score: []
+ bbox: []
+
+ """
+
+ category_id = 1 # the default for superanimal. But it might be changed
+
+ images = []
+ annotations = []
+ categories = []
+ annotation_id = 0
+ image_folder = os.path.join(dest_proj_folder, "images")
+
+ # video_to_frames function by default outputs png or jpg
+ image_paths = sorted(glob.glob(os.path.join(image_folder, "*.png")))
+
+ # skipping every 4 frames should speed up and not impact the performance
+ predictions, image_paths = predictions[::10], image_paths[::10]
+
+ # Since the inference API does not return the image path, I assume the predictions are provided in the same order as the frames in the video.
+ assert len(image_paths) == len(
+ predictions
+ ), f"number of images must be equal to number of predictions. image_paths: {len(image_paths)} , predictions: {len(predictions)}"
+ new_predictions = []
+
+ num_kpts = len(bodyparts)
+
+ if not superanimal_name.startswith("superanimal_"):
+ raise ValueError("not supporting non superanimal model video adaptation yet")
+
+ category_name = superanimal_name[len("superanimal_"):]
+ categories = [
+ {
+ "name": category_name,
+ "id": 1,
+ "supercategory": "animal",
+ "keypoints": bodyparts,
+ }
+ ]
+
+ assert len(predictions) == len(image_paths)
+ imageid2annotations = defaultdict(list)
+ for image_id, (prediction, image_path) in enumerate(zip(predictions, image_paths)):
+ image_obj = cv2.imread(image_path)
+ height, width, channels = image_obj.shape
+ imagename = image_path.split(os.sep)[-1]
+ image = {
+ "id": image_id,
+ "file_name": imagename,
+ "width": width,
+ "height": height,
+ }
+
+ # iterate through individuals if there are many
+
+ assert (
+ len(prediction["bodyparts"])
+ == len(prediction["bboxes"])
+ == len(prediction["bbox_scores"])
+ )
+ for pose, bbox, bbox_score in zip(
+ prediction["bodyparts"], prediction["bboxes"], prediction["bbox_scores"]
+ ):
+ if (
+ np.all(np.array(pose) <= 0)
+ or len(bbox) == 0
+ or bbox_score < bbox_threshold
+ ):
+ continue
+ imageid2annotations[image_id].append(pose)
+ pose = np.array(pose)
+ bbox = np.array(bbox)
+
+ mask = pose[:, -1] < pose_threshold
+
+ pose[mask] = 0
+
+ # by default all visible
+ pose[:, -1] = 2
+ bbox_confidence = bbox[-1]
+
+ keypoints = list(pose.reshape(-1))
+ keypoints = [float(num) for num in keypoints]
+ # bbox here is x,y,w,h from dlc3
+ bbox = [float(num) for num in bbox][:4]
+
+ anno = {
+ "category_id": int(category_id),
+ "keypoints": keypoints,
+ "num_keypoints": len(keypoints) // 3,
+ "image_id": int(image_id),
+ "bbox": bbox,
+ "area": float(bbox[-2] * bbox[-3]),
+ "iscrowd": 0,
+ "id": int(annotation_id),
+ }
+
+ annotation_id += 1
+ annotations.append(anno)
+
+ # this is to prevent images that do not have annotations
+ if len(imageid2annotations[image_id]) > 0:
+ images.append(image)
+
+ train_obj = {"images": images, "annotations": annotations, "categories": categories}
+
+ test_annotations = []
+
+ # just use the first 10 image annotations for test
+ test_obj = {
+ "images": images[:10],
+ "annotations": annotations[:10],
+ "categories": categories,
+ }
+
+ # there is no 'test' split of video adaptation. This is essentially train.json
+ with open(os.path.join(dest_proj_folder, "annotations", "test.json"), "w") as f:
+ json.dump(test_obj, f, indent=4)
+
+ with open(os.path.join(dest_proj_folder, "annotations", "train.json"), "w") as f:
+ json.dump(train_obj, f, indent=4)
diff --git a/deeplabcut/utils/visualization.py b/deeplabcut/utils/visualization.py
index 609b76f86f..cb6b4bfd6e 100644
--- a/deeplabcut/utils/visualization.py
+++ b/deeplabcut/utils/visualization.py
@@ -17,22 +17,33 @@
https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""
+from __future__ import annotations
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
+import pandas as pd
from matplotlib.collections import LineCollection
+from matplotlib.colors import Colormap
+import matplotlib.patches as patches
from skimage import io, color
from tqdm import trange
-from deeplabcut.utils import auxiliaryfunctions
+from deeplabcut.utils import auxiliaryfunctions, auxfun_videos
-def get_cmap(n, name="hsv"):
- """Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
- RGB color; the keyword argument name must be a standard mpl colormap name."""
+def get_cmap(n: int, name: str = "hsv") -> Colormap:
+ """
+ Args:
+ n: number of distinct colors
+ name: name of matplotlib colormap
+
+ Returns:
+ A function that maps each index in 0, 1, ..., n-1 to a distinct
+ RGB color; the keyword argument name must be a standard mpl colormap name.
+ """
return plt.cm.get_cmap(name, n)
@@ -104,21 +115,73 @@ def make_labeled_image(
def make_multianimal_labeled_image(
- frame,
- coords_truth,
- coords_pred,
- probs_pred,
- colors,
- dotsize=12,
- alphavalue=0.7,
- pcutoff=0.6,
- labels=["+", ".", "x"],
- ax=None,
-):
+ frame: np.ndarray,
+ coords_truth: np.ndarray | list,
+ coords_pred: np.ndarray | list,
+ probs_pred: np.ndarray | list,
+ colors: Colormap,
+ dotsize: float | int = 12,
+ alphavalue: float = 0.7,
+ pcutoff: float = 0.6,
+ labels: list = ["+", ".", "x"],
+ ax: plt.Axes | None = None,
+ bounding_boxes: tuple[np.ndarray, np.ndarray] | None = None,
+ bboxes_cutoff: float = 0.6,
+ bboxes_color: Colormap | str | None = None,
+) -> plt.Axes:
+ """
+ Plots groundtruth labels and predictions onto the matplotlib's axes, with the specified graphical parameters.
+
+ Args:
+ frame: image
+ coords_truth: groundtruth labels
+ coords_pred: predictions
+ probs_pred: prediction probabilities
+ colors: colors for poses
+ dotsize: size of dot
+ alphavalue: transparency for the keypoints
+ pcutoff: cut-off confidence value
+ labels: labels to use for ground truth, reliable predictions, and not reliable predictions (confidence below cut-off value)
+ ax: matplotlib plot's axes object
+ bounding_boxes: bounding boxes (top-left corner, size) and their respective confidence levels,
+ bboxes_cutoff: bounding boxes confidence cutoff threshold.
+ bboxes_color: color(s) for the bounding boxes.
+ If Colormap is passed -> each bounding box will be colored into its own color from the colormap.
+ If string is passed -> all bboxes will be of string's defined color.
+ If None -> all bboxes will be colored into a default color.
+
+ Returns:
+ matplotlib Axes object with plotted labels and predictions.
+ """
+
if ax is None:
h, w, _ = np.shape(frame)
_, ax = prepare_figure_axes(w, h)
ax.imshow(frame, "gray")
+
+ if bounding_boxes is not None:
+ for i, (bbox, bbox_score) in enumerate(
+ zip(bounding_boxes[0], bounding_boxes[1])
+ ):
+ bbox_origin = (bbox[0], bbox[1])
+ (bbox_width, bbox_height) = (bbox[2], bbox[3])
+ if isinstance(bboxes_color, Colormap):
+ bbox_color = bboxes_color(i)
+ elif bboxes_color is None:
+ bbox_color = "red"
+ else:
+ bbox_color = bboxes_color
+ rectangle = patches.Rectangle(
+ bbox_origin,
+ bbox_width,
+ bbox_height,
+ linewidth=1,
+ edgecolor=bbox_color,
+ facecolor="none",
+ linestyle="--" if bbox_score < bboxes_cutoff else "-",
+ )
+ ax.add_patch(rectangle)
+
for n, data in enumerate(zip(coords_truth, coords_pred, probs_pred)):
color = colors(n)
coord_gt, coord_pred, prob_pred = data
@@ -159,7 +222,10 @@ def plot_and_save_labeled_frame(
ax,
scaling=1,
):
- image_path = os.path.join(cfg["project_path"], *DataCombined.index[ind])
+ if isinstance(DataCombined.index[ind], tuple):
+ image_path = os.path.join(cfg["project_path"], *DataCombined.index[ind])
+ else:
+ image_path = os.path.join(cfg["project_path"], DataCombined.index[ind])
frame = io.imread(image_path)
if np.ndim(frame) > 2: # color image!
h, w, numcolors = np.shape(frame)
@@ -375,3 +441,167 @@ def make_labeled_images_from_dataframe(
dpi=dpi,
)
plt.close(fig)
+
+
+def plot_evaluation_results(
+ df_combined: pd.DataFrame,
+ project_root: str,
+ scorer: str,
+ model_name: str,
+ output_folder: str,
+ in_train_set: bool,
+ plot_unique_bodyparts: bool = False,
+ mode: str = "bodypart",
+ colormap: str = "rainbow",
+ dot_size: int = 12,
+ alpha_value: float = 0.7,
+ p_cutoff: float = 0.6,
+ bounding_boxes: dict | None = None,
+ bboxes_cutoff: float = 0.6,
+ bounding_boxes_color: str = "auto",
+) -> None:
+ """
+ Creates labeled images using the results of inference, and saves them to an output
+ folder.
+
+ Args:
+ df_combined: dataframe with multiindex rows ("labeled-data", video_name,
+ image_name) and columns ("scorer", "individuals", "bodyparts", "coords").
+ There should be two scorers: scorer (for ground truth data) and model_name
+ (for prediction data)
+ project_root: the project root path
+ scorer: the name of the scorer for ground truth data in df_combined
+ model_name: the name of the model for predictions in df_combined
+ output_folder: the name of the folder where images should be saved
+ in_train_set: whether df_combined is for train set images
+ plot_unique_bodyparts: whether we should plot unique bodyparts
+ mode: one of {"bodypart", "individual"}. Determines the keypoint color grouping
+ colormap: the colormap to use for keypoints
+ dot_size: the dot size to use for keypoints
+ alpha_value: the alpha value to use for keypoints
+ p_cutoff: the p-cutoff for "confident" keypoints
+ bounding_boxes: dictionary with df_combined rows as keys and bounding boxes
+ (np array for coordinates and np array for confidence).
+ None corresponds to no bounding boxes.
+ bboxes_cutoff: bounding boxes confidence cutoff threshold.
+ bounding_boxes_color: If plotting bounding boxes, this is the color that will be used for bounding boxes.
+ If set to "auto" (default value):
+ - if mode is "bodypart", the bbox color will be a default color
+ - if mode is "individual", each individual's color will be used for its bounding box
+
+ """
+ if bounding_boxes is None:
+ bounding_boxes = {}
+
+ for row_index, row in df_combined.iterrows():
+ if isinstance(row_index, str):
+ image_rel_path = Path(row_index)
+ data_folder = image_rel_path.parent.parent.name
+ video = image_rel_path.parent.name
+ image = image_rel_path.name
+ else:
+ data_folder, video, image = row_index
+
+ image_path = Path(project_root) / data_folder / video / image
+ frame = auxfun_videos.imread(str(image_path), mode="skimage")
+
+ row_multi = row.loc[
+ (slice(None), row.index.get_level_values("individuals") != "single")
+ ]
+ individuals = len(row_multi.index.get_level_values("individuals").unique())
+ bodyparts = len(row_multi.index.get_level_values("bodyparts").unique())
+ df_gt = row_multi[scorer]
+ df_predictions = row_multi[model_name]
+
+ # Shape (num_individuals, num_bodyparts, xy)
+ ground_truth = df_gt.to_numpy().reshape((individuals, bodyparts, 2))
+ predictions = df_predictions.to_numpy().reshape((individuals, bodyparts, 3))
+
+ bboxes = bounding_boxes.get(row_index)
+
+ if plot_unique_bodyparts:
+ row_unique = row.loc[
+ (slice(None), row.index.get_level_values("individuals") == "single")
+ ]
+ unique_individuals = 1
+ unique_bodyparts = len(
+ row_unique.index.get_level_values("bodyparts").unique()
+ )
+ unique_ground_truth = (
+ row_unique[scorer]
+ .to_numpy()
+ .reshape((unique_individuals, unique_bodyparts, 2))
+ )
+ unique_predictions = (
+ row_unique[model_name]
+ .to_numpy()
+ .reshape((unique_individuals, unique_bodyparts, 3))
+ )
+
+ fig, ax = create_minimal_figure()
+ h, w, _ = np.shape(frame)
+ fig.set_size_inches(w / 100, h / 100)
+ ax.set_xlim(0, w)
+ ax.set_ylim(0, h)
+ ax.invert_yaxis()
+
+ if mode == "bodypart":
+ num_colors = bodyparts
+ if plot_unique_bodyparts:
+ num_colors += unique_bodyparts
+
+ colors = get_cmap(num_colors, name=colormap)
+ predictions = predictions.swapaxes(0, 1)
+ ground_truth = ground_truth.swapaxes(0, 1)
+ elif mode == "individual":
+ colors = get_cmap(individuals + 1, name=colormap)
+ else:
+ colors = []
+
+ if bounding_boxes_color == "auto":
+ if mode == "bodypart":
+ bboxes_color = None
+ elif mode == "individual":
+ bboxes_color = get_cmap(individuals + 1, name=colormap)
+ else:
+ raise ValueError(f"Invalid mode: {mode}")
+ else:
+ bboxes_color = bounding_boxes_color
+
+ ax = make_multianimal_labeled_image(
+ frame=frame,
+ coords_truth=ground_truth,
+ coords_pred=predictions[:, :, :2],
+ probs_pred=predictions[:, :, 2:],
+ colors=colors,
+ dotsize=dot_size,
+ alphavalue=alpha_value,
+ pcutoff=p_cutoff,
+ ax=ax,
+ bounding_boxes=bboxes,
+ bboxes_cutoff=bboxes_cutoff,
+ bboxes_color=bboxes_color,
+ )
+ if plot_unique_bodyparts:
+ unique_predictions = unique_predictions.swapaxes(0, 1)
+ unique_ground_truth = unique_ground_truth.swapaxes(0, 1)
+ ax = make_multianimal_labeled_image(
+ frame=frame,
+ coords_truth=unique_ground_truth,
+ coords_pred=unique_predictions[:, :, :2],
+ probs_pred=unique_predictions[:, :, 2:],
+ colors=colors,
+ dotsize=dot_size,
+ alphavalue=alpha_value,
+ pcutoff=p_cutoff,
+ ax=ax,
+ )
+
+ save_labeled_frame(
+ fig,
+ str(image_path),
+ output_folder,
+ belongs_to_train=in_train_set,
+ )
+ erase_artists(ax)
+ plt.close()
diff --git a/deeplabcut/version.py b/deeplabcut/version.py
index 75d7e1ce07..0910e25f69 100644
--- a/deeplabcut/version.py
+++ b/deeplabcut/version.py
@@ -9,5 +9,5 @@
# Licensed under GNU Lesser General Public License v3.0
#
-__version__ = "2.3.11"
+__version__ = "3.0.0rc6"
VERSION = __version__
diff --git a/docs/Governance.md b/docs/Governance.md
index 4fe068f0db..ff7477f28e 100644
--- a/docs/Governance.md
+++ b/docs/Governance.md
@@ -1,6 +1,6 @@
(governance-model)=
# Governance Model of DeepLabCut
-(adapted from https://napari.org/docs/_sources/developers/GOVERNANCE.md.txt)
+(adapted from https://napari.org/stable/community/governance.html)
## Abstract
@@ -30,7 +30,7 @@ project in concrete ways, such as:
[GitHub pull request](https://github.com/DeepLabCut/DeepLabCut/pulls);
- reporting issues on our
[GitHub issues page](https://github.com/DeepLabCut/DeepLabCut/issues);
-- proposing a change to the documentation (http://docs.deeplabcut.org) via a
+- proposing a change to the [documentation](https://deeplabcut.github.io/DeepLabCut/README.html) via a
GitHub pull request;
- discussing the design of the `DeepLabCut` or its tutorials on in existing
[issues](https://github.com/DeepLabCut/DeepLabCut/issues) and
@@ -43,7 +43,7 @@ among other possibilities. Any community member can become a contributor, and
all are encouraged to do so. By contributing to the project, community members
can directly help to shape its future.
-Contributors are encouraged to read the [contributing guide](https://github.com/DeepLabCut/DeepLabCut/CONTRIBUTING.md).
+Contributors are encouraged to read the [contributing guide](https://github.com/DeepLabCut/DeepLabCut/blob/main/CONTRIBUTING.md).
### Core developers
diff --git a/docs/MISSION_AND_VALUES.md b/docs/MISSION_AND_VALUES.md
index 4677f07345..c82c4b7a39 100644
--- a/docs/MISSION_AND_VALUES.md
+++ b/docs/MISSION_AND_VALUES.md
@@ -3,14 +3,16 @@
This document is meant to help guide decisions about the future of `DeepLabCut`, be it in terms of
whether to accept new functionality, changes to the styling of the code or graphical user interfaces (GUI),
-or whether to take on new dependencies, when to break into other repos, among other things. It serves as a point of reference for core developers actively working on the project, and an introduction for
+or whether to take on new dependencies, when to break into other repos, among other things. It serves as a point of
+reference for core developers actively working on the project, and an introduction for
newcomers who want to learn a little more about where the project is going and what the team's
-values are. You can also learn more about how the project is managed by looking at our [governance model](governance-model).
+values are. You can also learn more about how the project is managed by looking at our
+[governance model](governance-model).
## Our founding principles
-The founding DeepLabCut team came together around a shared vision for building the first open-source animal pose estimation framework
-that is:
+The founding DeepLabCut team came together around a shared vision for building the first open-source animal pose
+estimation framework that is:
- user defined pose estimation - i.e. species or object agnostic.
- access to SOTA deep learning models that can be swiftly re-trained for customized applications
@@ -18,39 +20,57 @@ that is:
- scalable (project focused for ease of portability and sharability)
-As the project has grown we've turned these original principles into the mission statement and set of values that we described below.
+As the project has grown we've turned these original principles into the mission statement and set of values that we
+described below.
## Our mission
-DeepLabCut aims to be **the animal pose software package for Python** and to **provide access to deep learning-based pose estimation for people to use in their daily work** without the need to be able to program in a deep learning framework.
-We hope to accomplish this by:
+DeepLabCut aims to be **the animal pose software package for Python** and to **provide access to deep learning-based
+pose estimation for people to use in their daily work** without the need to be able to program in a deep learning
+framework. We hope to accomplish this by:
-- being **easy to use and install**. We are careful in taking on new dependencies, sometimes making them optional, and aim support a fully (Python) packaged installation that works cross-platform.
+- being **easy to use and install**. We are careful in taking on new dependencies, sometimes making them optional, and
+aim support a fully (Python) packaged installation that works cross-platform.
-- being **well-documented** with **comprehensive tutorials and examples**. All functions in our API have thorough docstrings clarifying expected inputs and outputs, and we maintain a separate [tutorials and information website](http://deeplabcut.org).
+- being **well-documented** with **comprehensive tutorials and examples**. All functions in our API have thorough
+docstrings clarifying expected inputs and outputs, and we maintain a separate
+[tutorials and information website](http://deeplabcut.org).
- providing **GUI access** to all critical functionality so DeepLabCut can be used by people without coding experience.
- being **interactive** and **highly performant** in order to support large data pipelines.
-- providing a **consistent and stable API** to enable plugin developers to build on top of DeepLabCut without their code constantly breaking and to enable advanced users to build out sophisticated Python workflows, if needed.
+- providing a **consistent and stable API** to enable plugin developers to build on top of DeepLabCut without their
+code constantly breaking and to enable advanced users to build out sophisticated Python workflows, if needed.
-- **ensuring correctness**. We strive for complete test coverage of both the code and GUI, with all code reviewed by a core developer before being included in the repository.
+- **ensuring correctness**. We strive for complete test coverage of both the code and GUI, with all code reviewed by a
+core developer before being included in the repository.
## Our values
-- We are **inclusive**. We welcome newcomers who are making their first contribution and strive to grow our most dedicated contributors into [core developers](https://github.com/orgs/DeepLabCut/teams/core-developers). We have a [Code of Conduct](https://github.com/DeepLabCut/DeepLabCut/CODE_OF_CONDUCT.md) to make DeepLabCut
+- We are **inclusive**. We welcome newcomers who are making their first contribution and strive to grow our most
+dedicated contributors into [core developers](https://github.com/orgs/DeepLabCut/teams/core-developers).
+We have a [Code of Conduct](https://github.com/DeepLabCut/DeepLabCut/blob/main/CODE_OF_CONDUCT.md) to make DeepLabCut
a welcoming place for all.
-- We are **community-engaged**. We respond to feature requests and proposals on our [issue tracker](https://github.com/DeepLabCut/DeepLabCut/issues).
+- We are **community-engaged**. We respond to feature requests and proposals on our
+- [issue tracker](https://github.com/DeepLabCut/DeepLabCut/issues).
-- We serve **scientific applications** primarily, over “consumer or commercial” pose estimation tools. This often means prioritizing core functionality support, and rejecting implementations of “flashy” features that have little scientific value.
+- We serve **scientific applications** primarily, over “consumer or commercial” pose estimation tools. This often means
+prioritizing core functionality support, and rejecting implementations of “flashy” features that have little
+scientific value.
-- We are **domain agnostic** within the sciences. Functionality that is highly specific to particular scientific domains belongs in plugins, whereas functionality that cuts across many domains and is likely to be widely used belongs inside DeepLabCut.
+- We are **domain agnostic** within the sciences. Functionality that is highly specific to particular scientific
+domains belongs in plugins, whereas functionality that cuts across many domains and is likely to be widely used belongs
+inside DeepLabCut.
-- We value **education and documentation**. All functions should have docstrings, preferably with examples, and major functionality should be explained in our [tutorials](http://deeplabcut.org). Core developers can take an active role in finishing documentation examples.
+- We value **education and documentation**. All functions should have docstrings, preferably with examples, and major
+functionality should be explained in our [tutorials](http://deeplabcut.org). Core developers can take an active role
+in finishing documentation examples.
## Acknowledgements
-We share a lot of our mission and values with [`napari`](https://napari.org/docs/developers/MISSION_AND_VALUES.html) and [`scikit-image`](https://scikit-image.org/docs/dev/values.html) and acknowledge the influence of their mission and values statements on this document.
+We share a lot of our mission and values with [`napari`](https://napari.org/stable/community/mission_and_values.html)
+and [`scikit-image`](https://scikit-image.org/docs/stable/about/values.html) and acknowledge the influence of their
+mission and values statements on this document.
diff --git a/docs/ModelZoo.md b/docs/ModelZoo.md
index 82ebacfec5..ca5d790659 100644
--- a/docs/ModelZoo.md
+++ b/docs/ModelZoo.md
@@ -5,12 +5,14 @@
## 🏠 [Home page](http://modelzoo.deeplabcut.org/)
-Started in 2020, expanded in 2022 with PhD student [Shaokai Ye et al.](https://arxiv.org/abs/2203.07436v1), and the first proper [SuperAnimal Foundation Models]() published in 2024 🔥, the Model Zoo is four things:
+
+Started in 2020, expanded in 2022 with PhD student [Shaokai Ye et al.](https://arxiv.org/abs/2203.07436v1), and the
+first proper [SuperAnimal Foundation Models](#about-the-superanimal-models) published in 2024 🔥, the Model Zoo is four things:
- (1) a collection of models that are trained on diverse data across (typically) large datasets, which means you do not need to train models yourself, rather you can use them in your research applications.
- (2) a contribution website for community crowd sourcing of expertly labeled keypoints to improve models! You can get involved here: [contrib.deeplabcut.org](https://contrib.deeplabcut.org/).
- (3) a no-install DeepLabCut that you can use on ♾[Google Colab](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_DEMO_SuperAnimal.ipynb),
-test our models in 🕸[the browser](https://contrib.deeplabcut.org/), or on our 🤗[HuggingFace](https://huggingface.co/spaces/DeepLabCut/MegaDetector_DeepLabCut) app!
+test our models in 🕸[the browser](https://contrib.deeplabcut.org/), or on our 🤗[HuggingFace](https://huggingface.co/spaces/DeepLabCut/DeepLabCutModelZoo-SuperAnimals) app!
- (4) new methods to make SuperAnimal Foundation Models that combine data across different labs/datasets, keypoints, animals/species, and use on your data!
## Quick Start:
@@ -28,16 +30,18 @@ To provide the community with easy access to such high performance models across
- Models are based on what they are trained on, for example `superanimal_quadruped_x` is trained on [SuperAnimal-Quadruped-80K](https://zenodo.org/records/10619173). Each model class is described below:
+
### SuperAnimal-Quadruped:
+
- `superanimal_quadruped_x` models aim to work across a large range of quadruped animals, from horses, dogs, sheep, rodents, to elephants. The camera perspective is orthogonal to the animal ("side view"), and most of the data includes the animals face (thus the front and side of the animal). You will note we have several variants that differ in speed vs. performance, so please do test them out on your data to see which is best suited for your application. Also note we have a "video adaptation" feature, which lets you adapt your data to the model in a self-supervised way. No labeling needed!
-- [PLEASE SEE THE FULL DATASHEET HERE](https://zenodo.org/records/10619173)
-- [MORE DETAILS ON THE MODELS (detector, pose estimators)](https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped)
+- [Please see the full datasheet here](https://zenodo.org/records/10619173)
+- [More details on the models (detector, pose estimators)](https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-Quadruped)
- We provide several models:
- `superanimal_quadruped_hrnetw32` (pytorch engine)
- `superanimal_quadruped_hrnetw32` is a top-down model that is paired with a detector. That means it takes a cropped image from an object detector and predicts the keypoints. The object detector is currently a trained [ResNet50-based Faster-RCNN](https://pytorch.org/vision/stable/models/faster_rcnn.html).
- `superanimal_quadruped_dlcrnet` (tensorflow engine)
- - `superanimal_quadruped_dlcrnet` is a bottom-up model that predicts all keypoints then groups them into individuals. This can be faster, but more error prone.
+ - `superanimal_quadruped_dlcrnet` is a bottom-up model that predicts all keypoints, then groups them into individuals. This can be faster, but more error prone.
- `superanimal_quadruped` -> This is the same as `superanimal_quadruped_dlcrnet`, this was the old naming and being depreciated.
- For all models, they are automatically downloaded to modelzoo/checkpoints when used.
@@ -45,11 +49,13 @@ To provide the community with easy access to such high performance models across

+
### SuperAnimal-TopViewMouse:
+
- `superanimal_topviewmouse_x` aims to work across lab mice in different lab settings from a top-view perspective; this is very polar in many behavioral assays in freely moving mice.
-- [PLEASE SEE THE FULL DATASHEET HERE](https://zenodo.org/records/10618947)
-- [MORE DETAILS ON THE MODELS (detector, pose estimators)](https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse)
+- [Please see the full datasheet here](https://zenodo.org/records/10618947)
+- [More details on the models (detector, pose estimators)](https://huggingface.co/mwmathis/DeepLabCutModelZoo-SuperAnimal-TopViewMouse)
- We provide several models:
- `superanimal_topviewmouse_hrnetw32` (pytorch engine)
- `superanimal_topviewmouse_hrnetw32` is a top-down model that is paired with a detector. That means it takes a cropped image from an object detector and predicts the keypoints. The object detector is currently a trained [ResNet50-based Faster-RCNN](https://pytorch.org/vision/stable/models/faster_rcnn.html).
@@ -69,11 +75,14 @@ You can simply call the model and run video inference.
To note, a good step is typically to use our self-supervised video adaptation method to reduce jitter. In the `deeplabcut.video_inference_superanimal` simply function set the `video_adapt` option to __True__. Be aware, that enabling this option will (minimally) extend the processing time.
```python
-video_path = 'demo-video.mp4'
-superanimal_name = 'superanimal_quadruped_hrnetw32'
+import deeplabcut
+video_path = "demo-video.mp4"
+superanimal_name = "superanimal_quadruped"
deeplabcut.video_inference_superanimal([video_path],
superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
video_adapt = False)
```
@@ -83,13 +92,19 @@ deeplabcut.video_inference_superanimal([video_path],
In our work we introduced a spatial-pyramid for smartly rescaling images. Imagine if you frames are much larger than what we trained on, it would be hard for the model to find the animal! Here, you can simply guide the model with the `scale_list`:
```python
-video_path = 'demo-video.mp4'
-superanimal_name = 'superanimal_quadruped_dlcrnet'
+import deeplabcut
+video_path = "demo-video.mp4"
+superanimal_name = "superanimal_quadruped"
# The purpose of the scale list is to aggregate predictions from various image sizes. We anticipate the appearance size of the animal in the images to be approximately 400 pixels.
scale_list = range(200, 600, 50)
-deeplabcut.video_inference_superanimal([video_path], superanimal_name, scale_list=scale_list, video_adapt = False)
+deeplabcut.video_inference_superanimal([video_path],
+ superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ scale_list=scale_list,
+ video_adapt = False)
```
#### Practical example: Using transfer learning with superanimal weights.
@@ -100,37 +115,56 @@ Specifically:
* `superanimal_topviewmouse_x` uses 27 keypoints
```python
-superanimal_name = "superanimal_topviewmouse_hrnetw32"
+import os
+import deeplabcut
+from deeplabcut.modelzoo import build_weight_init
+
+superanimal_name = "superanimal_topviewmouse"
+
config_path = os.path.join(os.getcwd(), "openfield-Pranav-2018-10-30", "config.yaml")
-deeplabcut.create_training_dataset(config_path, superanimal_name = superanimal_name)
+weight_init = build_weight_init(
+ cfg=config_path,
+ super_animal=superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ with_decoder=False,
+)
+
+deeplabcut.create_training_dataset(config_path, weight_init = weight_init)
deeplabcut.train_network(config_path,
- maxiters=10,
+ epochs=10,
superanimal_name = superanimal_name,
superanimal_transfer_learning = True)
```
-
-
-
### Potential failure modes for SuperAnimal Models and how to fix it.
-Spatial domain shift: typical DNN models suffer from the spatial resolution shift between training datasets and test videos. To help find the proper resolution for our model, please try a range of `scale_list` in the API (details in the API docs). For `superanimal_quadruped`, we empirically observe that if your video is larger than 1500 pixels, it is better to pass `scale_list` in the range within 1000.
+Spatial domain shift: typical DNN models suffer from the spatial resolution shift between training datasets and test
+videos. To help find the proper resolution for our model, please try a range of `scale_list` in the API (details in the
+API docs). For `superanimal_quadruped`, we empirically observe that if your video is larger than 1500 pixels, it is
+better to pass `scale_list` in the range within 1000.
-Pixel statistics domain shift: The brightness of your video might look very different from our training datasets. This might either result in jittering predictions in the video or fail modes for lab mice videos (if the brightness of the mice is unusual compared to our training dataset). You can use our "video adaptation" model (released soon) to counter this.
+Pixel statistics domain shift: The brightness of your video might look very different from our training datasets.
+This might either result in jittering predictions in the video or fail modes for lab mice videos (if the brightness of
+the mice is unusual compared to our training dataset). You can use our "video adaptation" model to counter this.
### Our longer term perspective ...
-Via DeepLabCut Model Zoo, we aim to provide plug and play models that do not need any labeling and will just work decently on novel videos. If the predictions are not great enough due to failure modes described below, please give us feedback! We are rapidly improving our models and adaptation methods. We will also continue to expand this project to new model/data classes. Please do get in touch is you have data or ideas: modelzoo@deeplabcut.org
+Via DeepLabCut Model Zoo, we aim to provide plug and play models that do not need any labeling and will just work
+decently on novel videos. If the predictions are not great enough due to failure modes described below, please give us
+feedback! We are rapidly improving our models and adaptation methods. We will also continue to expand this project to
+new model/data classes. Please do get in touch is you have data or ideas: modelzoo@deeplabcut.org
## Publication:
-To see the first preprint on the work, click [here](https://arxiv.org/abs/2203.07436v1).
+To see the first preprint on the work, click [here](https://arxiv.org/abs/2203.07436v1).
-Our first publication on this project is now published at Nature Communications:
+Our first [publication](https://www.nature.com/articles/s41467-024-48792-2) on this project is now published at Nature
+Communications:
```{hint}
Here is the citation:
diff --git a/docs/Overviewof3D.md b/docs/Overviewof3D.md
index f7c8b63fa8..989ea484e9 100644
--- a/docs/Overviewof3D.md
+++ b/docs/Overviewof3D.md
@@ -1,14 +1,21 @@
(3D-overview)=
# 3D DeepLabCut
-In this repo we directly support 2-camera based 3D pose estimation. If you want n camera support, plus nicer optimization methods, please see our work that was published at [ICRA 2021 on strong baseline 3D models (and a 3D dataset)](https://github.com/African-Robotics-Unit/AcinoSet). In the link you will find how we optimize 6+ camera DLC output data for cheetahs (and see more below).
+In this repo we directly support 2-camera based 3D pose estimation. If you want n camera support, plus nicer
+optimization methods, please see our work that was published at
+[ICRA 2021 on strong baseline 3D models (and a 3D dataset)](https://github.com/African-Robotics-Unit/AcinoSet). In the
+link you will find how we optimize 6+ camera DLC output data for cheetahs (and see more below).
## **ATTENTION: Our code base in this repo assumes you:**
-A. You have 2D videos and a DeepLabCut network to analyze them as described in the [main documentation](overview). This can be with multiple separate networks for each camera (less recommended), or one network trained on all views - recommended! (See [Nath*, Mathis* et al., 2019](https://www.biorxiv.org/content/10.1101/476531v1)). We also support multi-animal 3D with this code (please see [Lauer et al. 2022](https://doi.org/10.1038/s41592-022-01443-0)).
+A. You have 2D videos and a DeepLabCut network to analyze them as described in the
+[main documentation](overview). This can be with multiple
+separate networks for each camera (less recommended), or one network trained on all views - recommended! (See
+[Nath*, Mathis* et al., 2019](https://www.biorxiv.org/content/10.1101/476531v1)). We also support multi-animal 3D with this code (please see
+[Lauer et al. 2022](https://doi.org/10.1038/s41592-022-01443-0)).
B. You are using 2 cameras, in a [stereo configuration](https://github.com/DeepLabCut/DeepLabCut/blob/5ac4c8cb6bcf2314a3abfcf979b8dd170608e094/deeplabcut/pose_estimation_3d/camera_calibration.py#L223), for 3D*.
@@ -20,12 +27,16 @@ Here are other excellent options for you to use that extend DeepLabCut:
-- **[AcinoSet](https://github.com/African-Robotics-Unit/AcinoSet)**; **n**-camera support with triangulation, extended Kalman filtering, and trajectory optimization code (see video to the right for a min demo, courtesy of Prof. Patel), plus a GUI to visualize 3D data. It is built to work directly with DeepLabCut (but currently tailored to cheetah's, thus some coding skills are required at this time).
+- **[AcinoSet](https://github.com/African-Robotics-Unit/AcinoSet)**; **n**-camera support with triangulation, extended Kalman filtering, and trajectory optimization
+code (see video to the right for a min demo, courtesy of Prof. Patel), plus a GUI to visualize 3D data. It is built to
+work directly with DeepLabCut (but currently tailored to cheetah's, thus some coding skills are required at this time).
-- **[anipose.org](https://anipose.readthedocs.io/en/latest/)**; a wrapper for 3D deeplabcut that provides >3 camera support and is built to work directly with DeepLabCut. You can `pip install anipose` into your DLC conda environment.
+- **[anipose.org](https://anipose.readthedocs.io/en/latest/)**; a wrapper for 3D deeplabcut that provides >3 camera support and is built to work directly with
+DeepLabCut. You can `pip install anipose` into your DLC conda environment.
-- **Argus, easywand or DLTdv** w/DeepLabCut see https://github.com/haliaetus13/DLCconverterDLT; this can be used with the the highly popular Argus or DLTdv tools for wand calibration.
+- **Argus, easywand or DLTdv** w/DeepLabCut see https://github.com/haliaetus13/DLCconverterDLT; this can be used with
+the the highly popular Argus or DLTdv tools for wand calibration.
## Jump in with direct DeepLabCut 2-camera support:
@@ -39,83 +50,118 @@ Here are other excellent options for you to use that extend DeepLabCut:
### (1) Create a New 3D Project:
-Watch a [DEMO VIDEO](https://youtu.be/Eh6oIGE4dwI) on how to use this code, and check out the Notebook [here](https://github.com/DeepLabCut/DeepLabCut/blob/master/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb)!
+Watch a [DEMO VIDEO](https://youtu.be/Eh6oIGE4dwI) on how to use this code, and check out the Notebook [here](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb)!
-You will run this function **one** time per project; a project is defined as a given set of cameras and calibration images. You can always analyze new videos within this project.
+You will run this function **one** time per project; a project is defined as a given set of cameras and calibration
+images. You can always analyze new videos within this project.
-The function **create\_new\_project\_3d** creates a new project directory specifically for converting the 2D pose to 3D pose, required subdirectories, and a basic 3D project configuration file. Each project is identified by the name of the project (e.g. Task1), name of the experimenter (e.g. YourName), as well as the date at creation.
+The function **create\_new\_project\_3d** creates a new project directory specifically for converting the 2D pose to 3D
+pose, required subdirectories, and a basic 3D project configuration file. Each project is identified by the name of the
+project (e.g. Task1), name of the experimenter (e.g. YourName), as well as the date at creation.
-Thus, this function requires the user to input the enter the name of the project, the name of the experimenter and number of cameras to be used. Currently, DeepLabCut supports triangulation using 2 cameras, but will expand to more than 2 cameras in a future version.
+Thus, this function requires the user to enter the name of the project, the name of the experimenter and number of
+cameras to be used. Currently, DeepLabCut supports triangulation using 2 cameras, but will expand to more than 2 cameras
+in a future version.
To start a 3D project type the following in ipython:
```python
-deeplabcut.create_new_project_3d('ProjectName','NameofLabeler',num_cameras = 2)
+deeplabcut.create_new_project_3d("ProjectName", "NameofLabeler", num_cameras=2)
```
-TIP 1: you can also pass ``working_directory=`Full path of the working directory'`` if you want to place this folder somewhere beside the current directory you are working in. If the optional argument ``working_directory`` is unspecified, the project directory is created in the current working directory.
+TIP 1: you can also pass `working_directory="Full path of the working directory"` if you want to place this folder
+somewhere beside the current directory you are working in. If the optional argument `working_directory` is unspecified,
+the project directory is created in the current working directory.
-TIP 2: you can also place ``config_path3d`` in front of ``deeplabcut.create_new_project_3d`` to create a variable that holds the path to the config.yaml file, i.e. ``config_path3d=deeplabcut.create_new_project_3d(...`` Or, set this variable for easy use. Please note that ``config_path3d='Full path of the 3D project configuration file'``.
+TIP 2: you can also place `config_path3d` in front of `deeplabcut.create_new_project_3d` to create a variable that holds
+the path to the config.yaml file, i.e. `config_path3d=deeplabcut.create_new_project_3d(...` Or, set this variable for
+easy use. Please note that `config_path3d='Full path of the 3D project configuration file'`.
- This function will create a project directory with the name **Name of the project+name of the experimenter+date of creation of the project+3d** in the **Working directory**. The project directory will have subdirectories: **calibration_images**, **camera_matrix**, **corners**, and **undistortion**. All the outputs generated during the course of a project will be stored in one of these subdirectories, thus allowing each project to be curated in separation from other projects.
+This function will create a project directory with the name **Name of the project+name of the experimenter+date of
+creation of the project+3d** in the **Working directory**. The project directory will have subdirectories:
+**calibration_images**, **camera_matrix**, **corners**, and **undistortion**. All the outputs generated during the
+course of a project will be stored in one of these subdirectories, thus allowing each project to be curated in
+separation from other projects.
- The purpose of the subdirectories is as follows:
+The purpose of the subdirectories is as follows:
- **calibration_images:** This directory will contain a set of calibration images acquired from the two cameras. A calibration image can be acquired using a printed checkerboard and its pair wise images are taken from both the cameras to consider as a set of calibration images. These pair of images are saved as ``.jpg`` with camera names as the prefix. e.g. ``camera-1-01.jpg`` and ``camera-2-01.jpg`` for the first pair of images. While taking the images:
-- Keep the orientation of the chessboard same and do not rotate more than 30 degrees. Rotating the chessboard circular will change the origin across the frames and may result in incorrect order of detected corners.
-- Cover several distances, and within each distance, cover all parts of the image view (all corners and center).
-Use a chessboard as big as possible, ideally a chessboard with of at least 8x6 squares.
-- Aim for taking at least 70 pair of images as after corner detection, some of the images might need to be discarded due to either incorrect corner detection or incorrect order of detected corners.
+**calibration_images:** This directory will contain a set of calibration images acquired from the two cameras. A
+calibration image can be acquired using a printed checkerboard and its pair wise images are taken from both the cameras
+to consider as a set of calibration images.
- **camera_matrix:** This directory will store the parameter for both the cameras as a pickle file. Specifically, these pickle files contain the intrinsic and extrinsic camera parameters. While the intrinsic parameters represent a transformation from 3-D camera's coordinates into the image coordinates, the extrinsic parameters represent a rigid transformation from world coordinate system to the 3-D camera's coordinate system.
+**camera_matrix:** This directory will store the parameter for both the cameras as a pickle file. Specifically, these
+pickle files contain the intrinsic and extrinsic camera parameters. While the intrinsic parameters represent a
+transformation from 3-D camera's coordinates into the image coordinates, the extrinsic parameters represent a rigid
+transformation from world coordinate system to the 3-D camera's coordinate system.
- **corners:** As a part of camera calibration, the checkerboard pattern is detected in the calibration images and these patterns will be stored in this directory. Each row of the checkerboard grid is marked with a unique color.
+**corners:** As a part of camera calibration, the checkerboard pattern is detected in the calibration images and these
+patterns will be stored in this directory. Each row of the checkerboard grid is marked with a unique color.
- **undistortion:** In order to check for calibration, the calibration images and the corresponding corner points are undistorted. These undistorted images are overlaid with undistorted points and will be stored in this directory.
+**undistortion:** In order to check for calibration, the calibration images and the corresponding corner points are
+undistorted. These undistorted images are overlaid with undistorted points and will be stored in this directory.
- Here is an overview of the calibration and triangulation workflow that follows:
+Here is an overview of the calibration and triangulation workflow that follows:
-
+
### (2) Take and Process Camera Calibration Images:
- (**CRITICAL!**) You must take images of a checkerboard to calibrate your images. Here are example boards you could print and use (mount it on a flat, hard surface!): https://markhedleyjones.com/projects/calibration-checkerboard-collection.
+(**CRITICAL!**) You must take images of a checkerboard to calibrate your images. Here are example boards you could
+print and use (mount it on a flat, hard surface!):
+https://markhedleyjones.com/projects/calibration-checkerboard-collection.
- You must save the image pairs as .jpg files.
-- They should be named with the **camera-#** as the prefix, i.e. **camera-1-01.jpg** and **camera-2-01.jpg** for the first pair of images. Please note, this cannot be changed after the project is created.
+- They should be named with the **camera-#** as the prefix, i.e. **camera-1-01.jpg** and **camera-2-01.jpg** for the
+first pair of images. Please note, this cannot be changed after the project is created.
-**TIP:** If you want to take a short video (vs. snapping pairs of frames) while you move the checkerboard around, you can use this command inside your conda environment (but outside of ipython!) to convert the video to **.jpg** frames (this will take the first 20 frames (set with ``-vframes``) and name them camera-1-001.jpg, etc; edit appropriately):
+**TIP:** If you want to take a short video (vs. snapping pairs of frames) while you move the checkerboard around, you
+can use this command inside your conda environment (but outside of ipython!) to convert the video to **.jpg** frames
+(this will take the first 20 frames (set with `-vframes`) and name them camera-1-001.jpg, etc; edit appropriately):
```python
ffmpeg -i videoname.mp4 -vframes 20 camera-1-%03d.jpg
```
- While taking the images:
- - Keep the orientation of the checkerboard the same and do not rotate it more than 30 degrees. Rotating the checkerboard circular will change the origin across the frames and may result in incorrect order of detected corners.
+ - Keep the orientation of the checkerboard the same and do not rotate it more than 30 degrees. Rotating the
+ checkerboard circular will change the origin across the frames and may result in incorrect order of detected corners.
- - Cover several distances, and within each distance, cover all parts of the image view (all corners and center).
+ - Cover several distances, and within each distance, cover all parts of the image view (all corners and center).
- - Use a checkerboard as big as possible, ideally with at least 8x6 squares.
+ - Use a checkerboard as big as possible, ideally with at least 8x6 squares.
- - Aim for taking at least 30-70 pair of images, as after corner detection, some of the images might need to be discarded due to either incorrect corner detection or incorrect order of detected corners.
+ - Aim for taking at least 30-70 pair of images, as after corner detection, some of the images might need to be
+ discarded due to either incorrect corner detection or incorrect order of detected corners.
- - You can take the images as a series of .jpg images, or a video where you post-hoc pair sync'd frames (see tip above).
+ - You can take the images as a series of .jpg images, or a video where you post-hoc pair sync'd frames (see tip
+ above).
-The camera calibration is an **iterative process**, where the user needs to select a set of calibration images where the grid pattern is correctly detected. The function:``deeplabcut.calibrate_cameras(config_path)``
-extracts the grid pattern from the calibration images and store them under the `corners` directory. The grid pattern could be 8x8 or 5x5 etc. We use a pattern of the 8x6 grid to find the internal corners of the checkerboard.
+The camera calibration is an **iterative process**, where the user needs to select a set of calibration images where the
+grid pattern is correctly detected. The function `deeplabcut.calibrate_cameras(config_path)`
+extracts the grid pattern from the calibration images and store them under the `corners` directory. The grid pattern
+could be 8x8 or 5x5 etc. We use a pattern of the 8x6 grid to find the internal corners of the checkerboard.
-In some cases, it may happen that the corners are not detected correctly or the order of corners detected in the camera-1 image and camera-2 image is incorrect. You need to remove these pair of images from the **calibration_images** folder as they will reduce the calibration accuracy.
+In some cases, it may happen that the corners are not detected correctly or the order of corners detected in the
+camera-1 image and camera-2 image is incorrect. You need to remove these pair of images from the **calibration_images**
+folder as they will reduce the calibration accuracy.
To begin, please place your images into the **calibration_images** directory.
- (**CRITICAL!**) Edit the **config.yaml** file to set the camera names; note that once this is set, **do not change the names!**
+(**CRITICAL!**) Edit the **config.yaml** file to set the camera names; note that once this is set, **do not change the
+names!**
Then, run:
```python
deeplabcut.calibrate_cameras(config_path3d, cbrow=8, cbcol=6, calibrate=False, alpha=0.9)
```
-NOTE: you need to specify how many rows (``cbrow``) and columns (``cbcol``) your checkerboard has. Also, first set the variable ``calibrate`` to **False**, so you can remove any faulty images. You need to visually inspect the output to check for the detected corners and select those pair of images where the corners are correctly detected. Please note, If the scaling parameter ``alpha=0``, it returns undistorted image with minimum unwanted pixels. So it may even remove some pixels at image corners. If ``alpha=1``, all pixels are retained with some extra black images.
+
+NOTE: you need to specify how many rows (`cbrow`) and columns (`cbcol`) your checkerboard has (beware, we count
+edges between squares and not squares themselves, so for a 8 x 8 squares checkerboard set `cbrow=7` and `cbcol=7`).
+Also, first set the variable `calibrate` to **False**, so you can remove any faulty images. You need to visually
+inspect the output to check for the detected corners and select those pair of images where the corners are correctly
+detected. Please note, If the scaling parameter `alpha=0`, it returns undistorted image with minimum unwanted pixels.
+So it may even remove some pixels at image corners. If `alpha=1`, all pixels are retained with some extra black images.
Here is what they might look like:
@@ -125,55 +171,85 @@ Here is what they might look like:
-Once all the set of images are selected (namely, delete from the folder any bad pairs!) where the corners and their orders are detected correctly, then the two cameras can be calibrated using:
+Once all the set of images has been selected (namely, delete from the folder any bad pairs!) where the corners and their
+orders are detected correctly, then the two cameras can be calibrated using:
```python
deeplabcut.calibrate_cameras(config_path3d, cbrow=8, cbcol=6, calibrate=True, alpha=0.9)
```
-This computes the intrinsic and extrinsic parameters for each camera. A re-projection error is also computed using the intrinsic and extrinsic parameters which provide an estimate of how good the parameters are. The transformation between the two cameras are estimated and the cameras are stereo calibrated. Furthermore, the above function brings both the camera image plane to the same plane by computing the stereo rectification. These parameters are stored as a pickle file named as `stereo_params.pickle` under the directory `camera_matrix`.
+This computes the intrinsic and extrinsic parameters for each camera. A re-projection error is also computed using the
+intrinsic and extrinsic parameters which provide an estimate of how good the parameters are. The transformation between
+the two cameras is estimated and the cameras are stereo calibrated. Furthermore, the above function brings both the
+camera image plane to the same plane by computing the stereo rectification. These parameters are stored as a pickle file
+named as `stereo_params.pickle` under the directory `camera_matrix`.
-Once you have run this for the project, you do not need to do so again (unless you want to re-calibrate your cameras); be advised, if you do re-calibrate, you may want to clearly mark which videos are analyzed with "old" vs. "new" calibration images.
+Once you have run this for the project, you do not need to do so again (unless you want to re-calibrate your cameras);
+be advised, if you do re-calibrate, you may want to clearly mark which videos are analyzed with "old" vs. "new"
+calibration images.
### (3) Check for Undistortion:
-In order to check how well the stereo calibration is, it is recommended to undistort the calibration images and the corner points using camera matrices and project these undistorted points on the undistorted images to check if they align correctly. This can be done in deeplabcut as:
+In order to check how well the stereo calibration is, it is recommended to undistort the calibration images and the
+corner points using camera matrices and project these undistorted points on the undistorted images to check if they
+align correctly. This can be done in deeplabcut as:
```python
deeplabcut.check_undistortion(config_path3d, cbrow=8, cbcol=6)
```
-Each calibration image is undistorted and saved under the directory ``undistortion``. A plot with a pair of undistorted camera images with its undistorted corner points overlaid is also stored. Please visually inspect this image. All the undistorted corner points from all the calibration images are triangulated and plotted for the user to visualize for any undistortion related errors. If they are not correct, go check and revise the calibration images (then repeat the calibration and this step)!
+Each calibration image is undistorted and saved under the directory `undistortion`. A plot with a pair of undistorted
+camera images with its undistorted corner points overlaid is also stored. Please visually inspect this image. All the
+undistorted corner points from all the calibration images are triangulated and plotted for the user to visualize for any
+undistortion related errors. If they are not correct, go check and revise the calibration images (then repeat the
+calibration and this step)!
### (4) Triangulation --> Take your 2D to 3D!
-If there are no errors in the undistortion, then the pose from the 2 cameras can be triangulated to get the 3D DeepLabCut coordinates!
+If there are no errors in the undistortion, then the pose from the 2 cameras can be triangulated to get the 3D
+DeepLabCut coordinates!
- (**CRITICAL!**) Name the video files in such a way that the file name **contains the name of the cameras** as specified in the ``config file``. e.g. if the cameras as named as ``camera-1`` and ``camera-2`` (or ``cam-1``, ``cam-2`` etc.) then the video filename must contain this naming, i.e. this could be named as ``rig-1-mouse-day1-camera-1.avi`` and ``rig-1-mouse-day1-camera-2.avi`` or could be ``rig-1-mouse-day1-camera-1-date.avi`` and ``rig-1-mouse-day1-camera-2-date.avi``.
+(**CRITICAL!**) Name the video files in such a way that the file name **contains the name of the cameras** as specified
+in the `config file`. e.g. if the cameras as named as `camera-1` and `camera-2` (or `cam-1`, `cam-2` etc.) then the
+video filename must contain this naming, i.e. this could be named as `rig-1-mouse-day1-camera-1.avi` and
+`rig-1-mouse-day1-camera-2.avi` or could be `rig-1-mouse-day1-camera-1-date.avi` and
+`rig-1-mouse-day1-camera-2-date.avi`.
- **Note** that to correctly pair the videos, the file names otherwise need to be the same!
- If helpful, [here is the software we use to record videos](https://github.com/AdaptiveMotorControlLab/Camera_Control).
-- **Note** that the videos do not need to be the same pixel size, but be sure they are similar in size to the calibration images (and they must be the same cameras used for calibration).
- (**CRITICAL!**) You must also edit the **3D project config.yaml** file to denote which DeepLabCut projects have the information for the 2D views.
+(**CRITICAL!**) You must also edit the **3D project config.yaml** file to denote which DeepLabCut projects have the
+information for the 2D views.
- - Of critical importance is that you need to input the **same** body part names as in the config.yaml file of the 2D project.
-- You must set the snapshot to use inside the 2D config file (default is -1, namely the last training snapshot of the network).
+- Of critical importance is that you need to input the **same** body part names as in the config.yaml file of the 2D
+project.
+- You must set the snapshot to use inside the 2D config file (default is -1, namely the last training snapshot of the
+network).
- You need to set a "scorer 3D" name; this will point to the project file and be set in future 3D output file names.
-- You should define a "skeleton" here as well (note, this is not rigid, it just connects the points in the plotting step). Not every point needs to be "skeletonized", i.e. these points can be a subset of the full body parts list. The other points will just be plotted into the 3D space. Here is how the config.yaml looks with some example inputs:
+- You should define a "skeleton" here as well (note, this is not rigid, it just connects the points in the plotting
+step). Not every point needs to be "skeletonized", i.e. these points can be a subset of the full body parts list. The
+other points will just be plotted into the 3D space. Here is how the config.yaml looks with some example inputs:
-(**CRITICAL!**) This step will also run the equivalent of ``analyze_videos`` (in 2D) for you and then apply a median filter to the 2D data (``filterpredictions=True`` is by default)! If you already ran the 2D analysis and there is a filtered output file, it will take this by default (otherwise it will take your unfiltered 2D analysis files)!
+(**CRITICAL!**) This step will also run the equivalent of `analyze_videos` (in 2D) for you and then apply a median
+filter to the 2D data (`filterpredictions=True`)! If you already ran the 2D analysis and there is a filtered output
+file, it will take this by default (otherwise it will take your unfiltered 2D analysis files)!
-Next, pass the ``config_path3d`` and now the video folder path, which is the path to the **folder** where all the videos from two cameras are stored. The triangulation can be done in deeplabcut by typing:
+Next, pass the `config_path3d` and now the video folder path, which is the path to the **folder** where all the videos
+from two cameras are stored. The triangulation can be done in deeplabcut by typing:
```python
-deeplabcut.triangulate(config_path3d, '/yourcomputer/fullpath/videofolder', filterpredictions=True/False)
+deeplabcut.triangulate(
+ config_path3d,
+ "/yourcomputer/fullpath/videofolder",
+ filterpredictions=True/False
+)
```
-NOTE: Windows users, you must input paths as: ``r`C:\Users\computername\videofolder' `` or ``C:\\Users\\computername\\videofolder'``.
+NOTE: Windows users, you must input paths as: ``r`C:\Users\computername\videofolder'`` or
+``C:\\Users\\computername\\videofolder'``.
**TIP:** Here are all the parameters you can pass:
@@ -204,8 +280,15 @@ destfolder: string, optional
save_as_csv: bool, optional
Saves the predictions in a .csv file. The default is ``False``; if provided it must be either ``True`` or ``False``
+
+track_method: str, optional
+ Method used for tracking: "box" or "ellipse"
```
-The **triangulated file** is now saved under the same directory where the video files reside (or the destination folder you set)! This can be used for future analysis. This step can be run at anytime as you collect new videos, and easily added to your automated analysis pipeline, i.e. such as **replacing** ``deeplabcut.triangulate(config_path3d, video_path)`` with ``deeplabcut.analyze_videos`` (as if it's not analyzed in 2D already, this function will take care of it ;):
+The **triangulated file** is now saved under the same directory where the video files reside (or the destination folder
+you set)! This can be used for future analysis. This step can be run at anytime as you collect new videos, and easily
+added to your automated analysis pipeline, i.e. such as **replacing**
+`deeplabcut.triangulate(config_path3d, video_path)` with `deeplabcut.analyze_videos` (as if it's not analyzed in 2D
+already, this function will take care of it ;):
@@ -213,22 +296,36 @@ The **triangulated file** is now saved under the same directory where the video
### (5) Visualize your 3D DeepLabCut Videos:
-In order to visualize both the 2D videos with tracked points plut the pose in 3D, the user can create a 3D video for certain frames (these are large files, so we advise just looking at a subset of frames). The user can specify the config file, the **path of the triangulated file folder**, and specify the start and end frame indices to create a 3D labeled video. Note that the ``triangulated_file_folder`` is where the newly created file that ends with ``yourDLC_3D_scorername.h5`` is located. This can be done using:
+In order to visualize both the 2D videos with tracked points plut the pose in 3D, the user can create a 3D video for
+certain frames (these are large files, so we advise just looking at a subset of frames). The user can specify the config
+file, the **path of the triangulated file folder**, and specify the start and end frame indices to create a 3D labeled
+video. Note that the `triangulated_file_folder` is where the newly created file that ends with
+`yourDLC_3D_scorername.h5` is located. This can be done using:
```python
-deeplabcut.create_labeled_video_3d(config_path, ['triangulated_file_folder'], start=50, end=250)
+deeplabcut.create_labeled_video_3d(
+ config_path,
+ ["triangulated_file_folder"],
+ start=50,
+ end=250
+)
```
-**TIP:** (see more parameters below) You can set how the axis of the 3D plot on the far right looks by changing the variables ``xlim``, ``ylim``, ``zlim`` and ``view``. Your checkerboard_3d.png image which was created above will show you the axis ranges. Here is an example:
+**TIP:** (see more parameters below) You can set how the axis of the 3D plot on the far right looks by changing the
+variables `xlim`, `ylim`, `zlim` and `view`. Your checkerboard_3d.png image which was created above will show you the
+axis ranges. Here is an example:
-``View`` is used to set the elevation and azimuth of the axes (defaults are [113, 270], and you should play around to find the view-point you like!). Also note that the video is created from a set of .png files in a "temp" directory, so as soon as you run this command you can open the first image, and if you don't like the view, hit ``CNTRL+C`` to stop, edit the values, and start again!
+`View` is used to set the elevation and azimuth of the axes (defaults are [113, 270], and you should play around to find
+the view-point you like!). Also note that the video is created from a set of .png files in a "temp" directory, so as
+soon as you run this command you can open the first image, and if you don't like the view, hit `CNTRL+C` to stop, edit
+the values, and start again!
**Other optional parameters include:**
-
+here
```python
videofolder: string
Full path of the folder where the videos are stored. Use this if the vidoes are stored in a different location other than where the triangulation files are stored. By default is ``None`` and therefore looks for video files in the directory where the triangulation file is stored.
@@ -251,8 +348,25 @@ ylim: list
zlim: list
A list of integers specifying the limits for zaxis of 3d view. By default it is set to [None,None], where the z limit is set by taking the minimum and maximum value of the z coordinates for all the bodyparts.
+
+draw_skeleton: bool
+ If True adds a line connecting the body parts making a skeleton on on each frame. The body parts to be connected and the color of these connecting lines are specified in the config file. By default: True
+
+color_by : string, optional (default='bodypart')
+ Coloring rule. By default, each bodypart is colored differently.
+ If set to 'individual', points belonging to a single individual are colored the same.
+
+figsize: tuple[int, int], optional, default=(80, 8)
+ Size of the figure
+
+fps: int, optional, default=30
+ Frames per second
+
+dpi: int, optional, default=300
+ Dots per inch (resplution)
```
### If you use this code:
-We kindly ask that you cite [Mathis et al, 2018](https://www.nature.com/articles/s41593-018-0209-y) **&** [Nath*, Mathis*, et al., 2019](https://doi.org/10.1038/s41596-019-0176-0). If you use 3D multi-animal: [Lauer et al. 2022](https://doi.org/10.1038/s41592-022-01443-0).
+We kindly ask that you cite [Mathis et al, 2018](https://www.nature.com/articles/s41593-018-0209-y) **&** [Nath*, Mathis*, et al., 2019](https://doi.org/10.1038/s41596-019-0176-0). If you use 3D
+multi-animal: [Lauer et al. 2022](https://doi.org/10.1038/s41592-022-01443-0).
diff --git a/docs/UseOverviewGuide.md b/docs/UseOverviewGuide.md
index 778b5d2c0e..7e7ec1a137 100644
--- a/docs/UseOverviewGuide.md
+++ b/docs/UseOverviewGuide.md
@@ -18,11 +18,11 @@ We are primarily a package that enables deep learning-based pose estimation. We
- Decide on your needs: there are **two main modes, standard DeepLabCut or multi-animal DeepLabCut**. We highly recommend carefully considering which one is best for your needs. For example, a white mouse + black mouse would call for standard, while two black mice would use multi-animal. **[Important Information on how to use DLC in different scenarios (single vs multi animal)](important-info-regd-usage)** Then pick a user guide:
-- (1) [How to use standard DeepLabCut](single-animal-userguide)
-- (2) [How to use multi-animal DeepLabCut](multi-animal-userguide)
+ - (1) [How to use standard DeepLabCut](single-animal-userguide)
+ - (2) [How to use multi-animal DeepLabCut](multi-animal-userguide)
- To note, as of DLC3+ the single and multi-animal code bases are more integrated and we support **top-down**, **bottom-up**, and a new "hybrid" approach that is state-of-the-art, called **BUCTD** (bottom-up conditional top down), models.
- - If these terms are new to you, check out our [Primer on Motion Capture with Deep Learning!](https://www.sciencedirect.com/science/article/pii/S0896627320307170). In brief, both work for single or multiple animals and each method can be better or worse on your data.
+ - If these terms are new to you, check out our [Primer on Motion Capture with Deep Learning!](https://www.sciencedirect.com/science/article/pii/S0896627320307170). In brief, both work for single or multiple animals and each method can be better or worse on your data.
@@ -67,15 +67,19 @@ Getting Started: [a video tutorial on navigating the documentation!](https://www
-
+
### Overview of the workflow:
This page contains a list of the essential functions of DeepLabCut as well as demos. There are many optional parameters with each described function, which you can find [here](functionDetails.md). For additional assistance, you can use the [help](UseOverviewGuide.md#help) function to better understand what each function does.
-
-
-
+
+
+
+
+ View in full screen
+
+
You can have as many projects on your computer as you wish. You can have DeepLabCut installed in an [environment](/conda-environments) and always exit and return to this environment to run the code. You just need to point to the correct ``config.yaml`` file to [jump back in](/docs/UseOverviewGuide.md#tips-for-daily-use)! The documentation below will take you through the individual steps.
@@ -136,7 +140,7 @@ with the terminal interface you get the most versatility and options.
## Option 1: Demo Notebooks:
[VIDEO TUTORIAL AVAILABLE!](https://www.youtube.com/watch?v=DRT-Cq2vdWs)
-We provide Jupyter and COLAB notebooks for using DeepLabCut on both a pre-labeled dataset, and on the end user’s
+We provide Jupyter and COLAB notebooks for using DeepLabCut on both a pre-labeled dataset, and on the end user's
own dataset. See all the demo's [here!](/examples) Please note that GUIs are not easily supported in Jupyter in MacOS, as you need a framework build of python. While it's possible to launch them with a few tweaks, we recommend using the Project Manager GUI or terminal, so please follow the instructions below.
(using-project-manager-gui)=
diff --git a/docs/api/deeplabcut.analyze_videos.rst b/docs/api/deeplabcut.analyze_videos.rst
index d43e39ea4c..274a801f02 100644
--- a/docs/api/deeplabcut.analyze_videos.rst
+++ b/docs/api/deeplabcut.analyze_videos.rst
@@ -1 +1 @@
-.. autofunction:: deeplabcut.pose_estimation_tensorflow.predict_videos.analyze_videos
+.. autofunction:: deeplabcut.compat.analyze_videos
diff --git a/docs/api/deeplabcut.convert_detections2tracklets.rst b/docs/api/deeplabcut.convert_detections2tracklets.rst
new file mode 100644
index 0000000000..f69f721d84
--- /dev/null
+++ b/docs/api/deeplabcut.convert_detections2tracklets.rst
@@ -0,0 +1 @@
+.. autofunction:: deeplabcut.compat.convert_detections2tracklets
diff --git a/docs/api/deeplabcut.create_training_dataset_from_existing_split.rst b/docs/api/deeplabcut.create_training_dataset_from_existing_split.rst
new file mode 100644
index 0000000000..0b59472d91
--- /dev/null
+++ b/docs/api/deeplabcut.create_training_dataset_from_existing_split.rst
@@ -0,0 +1 @@
+.. autofunction:: deeplabcut.generate_training_dataset.trainingsetmanipulation.create_training_dataset_from_existing_split
diff --git a/docs/api/deeplabcut.evaluate_network.rst b/docs/api/deeplabcut.evaluate_network.rst
index f24ee4c481..56914774fe 100644
--- a/docs/api/deeplabcut.evaluate_network.rst
+++ b/docs/api/deeplabcut.evaluate_network.rst
@@ -1 +1 @@
-.. autofunction:: deeplabcut.pose_estimation_tensorflow.core.evaluate.evaluate_network
+.. autofunction:: deeplabcut.compat.evaluate_network
diff --git a/docs/api/deeplabcut.label_frames.rst b/docs/api/deeplabcut.label_frames.rst
index b1a810d284..4de3a1054c 100644
--- a/docs/api/deeplabcut.label_frames.rst
+++ b/docs/api/deeplabcut.label_frames.rst
@@ -1 +1 @@
-.. autofunction:: deeplabcut.gui.label_frames.label_frames
+.. autofunction:: deeplabcut.gui.tabs.label_frames.label_frames
diff --git a/docs/api/deeplabcut.refine_labels.rst b/docs/api/deeplabcut.refine_labels.rst
index 7c61d5586c..b54b640e46 100644
--- a/docs/api/deeplabcut.refine_labels.rst
+++ b/docs/api/deeplabcut.refine_labels.rst
@@ -1 +1 @@
-.. autofunction:: deeplabcut.gui.refine_labels.refine_labels
+.. autofunction:: deeplabcut.gui.tabs.label_frames.refine_labels
diff --git a/docs/api/deeplabcut.stitch_tracklets.rst b/docs/api/deeplabcut.stitch_tracklets.rst
new file mode 100644
index 0000000000..96677d31a9
--- /dev/null
+++ b/docs/api/deeplabcut.stitch_tracklets.rst
@@ -0,0 +1 @@
+.. autofunction:: deeplabcut.refine_training_dataset.stitch.stitch_tracklets
diff --git a/docs/api/deeplabcut.train_network.rst b/docs/api/deeplabcut.train_network.rst
index e724591d21..cd32c85295 100644
--- a/docs/api/deeplabcut.train_network.rst
+++ b/docs/api/deeplabcut.train_network.rst
@@ -1 +1 @@
-.. autofunction:: deeplabcut.pose_estimation_tensorflow.training.train_network
+.. autofunction:: deeplabcut.compat.train_network
diff --git a/docs/beginner-guides/Training-Evaluation.md b/docs/beginner-guides/Training-Evaluation.md
index de67d92198..18ab3bf736 100644
--- a/docs/beginner-guides/Training-Evaluation.md
+++ b/docs/beginner-guides/Training-Evaluation.md
@@ -40,7 +40,9 @@ After training, it's time to see how well your model performs.
### Understanding the Evaluation Results
-- **Performance Metrics:** DLC will assess the latest snapshot of your model, generating a `.CSV` file with performance metrics. This file is stored in the **`evaluate network`** folder within your project.
+- **Performance Metrics:** DLC will assess the latest snapshot of your model, generating a `.CSV` file with performance
+metrics. This file is stored in the **`evaluation-results`** (for TensorFlow models) or the
+**`evaluation-results-pytorch`** (for PyTorch models) folder within your project.
)
diff --git a/docs/beginner-guides/beginners-guide.md b/docs/beginner-guides/beginners-guide.md
index 0c8d323647..eeef2d8c4f 100644
--- a/docs/beginner-guides/beginners-guide.md
+++ b/docs/beginner-guides/beginners-guide.md
@@ -1,3 +1,4 @@
+(beginners-guide)=
# Using DeepLabCut
@@ -38,13 +39,13 @@ Now, we are going to install the core dependencies. The way this works is that t
`PyTorch` is the backend deep-learning language we wrote DLC3 in. To select the right version, head to the [“Install PyTorch”](https://pytorch.org/get-started/locally/) instructions in the official PyTorch Docs. Select your desired PyTorch build, operating system, select conda as your package manager and Python as the language. Select your compute platform (either a CUDA version or CPU only). Then, use the command to install the PyTorch package. Below are a few possible examples:
-- **GPU version of pytorch for CUDA 11.8**
+- **GPU version of pytorch for CUDA 12.4**
```
-conda install pytorch cudatoolkit=11.8 -c pytorch
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
```
- **CPU only version of pytorch, using the latest version**
```
-conda install pytorch cpuonly -c pytorch
+pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
```
**(2) Install DeepLabCut**
@@ -61,7 +62,7 @@ pip install "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=de
```
- OR run for the **Stable release:**
```
-pip install "deeplabcut[gui,modelzoo,tf]"
+pip install "deeplabcut[gui,modelzoo,wandb]"
```
- This gives you DeepLabCut, the DLC GUI (gui), our latest neural networks (modelzoo) and a cool data logger (wandb) if you choose to use it later on!
diff --git a/docs/beginner-guides/labeling.md b/docs/beginner-guides/labeling.md
index 7d944137a6..e5c7492722 100644
--- a/docs/beginner-guides/labeling.md
+++ b/docs/beginner-guides/labeling.md
@@ -1,3 +1,4 @@
+(labeling)=
# Labeling GUI
## Selecting Frames to Label
@@ -45,7 +46,9 @@ Alright, you've got your extracted frames ready. Now comes the labeling!
- **Navigate Through Frames:** Use the slider to go from one frame to the next after you're done labeling.
- **Save Progress:** Remember to save your work as you go with **`Command and S`** (or **`Ctrl and S`** on Windows).
-> 💡 **Note:** For a detailed walkthrough on using the Napari labeling GUI, have a look at the [DeepLabCut Napari Guide](https://deeplabcut.github.io/DeepLabCut/docs/napari_GUI.html). Additionally, you can watch our instructional [YouTube video](https://www.youtube.com/watch?v=hsA9IB5r73E) for more insights and tips.
+> 💡 **Note:** For a detailed walkthrough on using the Napari labeling GUI, have a look at the
+[DeepLabCut Napari Guide](napari-gui). Additionally, you can watch our instructional
+[YouTube video](https://www.youtube.com/watch?v=hsA9IB5r73E) for more insights and tips.
### Completing the Set
diff --git a/docs/beginner-guides/manage-project.md b/docs/beginner-guides/manage-project.md
index a53789360f..3f26589ae2 100644
--- a/docs/beginner-guides/manage-project.md
+++ b/docs/beginner-guides/manage-project.md
@@ -41,4 +41,4 @@ A **`Configuration Editor`** window will open, displaying all the configuration
- **Save the Configuration:** Once you're satisfied with the modifications, click **`Save`**. This will store your changes and return you to the main GUI window.
-## Next, head over the beginner guide for [Labeling your data](https://deeplabcut.github.io/DeepLabCut/docs/labelling)
+## Next, head over the beginner guide for [Labeling your data](labeling)
diff --git a/docs/beginner-guides/video-analysis.md b/docs/beginner-guides/video-analysis.md
index 59c976360d..849d8b6638 100644
--- a/docs/beginner-guides/video-analysis.md
+++ b/docs/beginner-guides/video-analysis.md
@@ -16,7 +16,7 @@ After training and evaluating your model, the next step is to apply it to your v
- **Find Results in Your Project Folder:** After analysis, go to your project's video folder.
- **Analysis Files:** Look also for a `.metapickle`, an `.h5`, and possibly a `.csv` file for detailed analysis data.
-- **Review the Plot Poses Subfolder:** This contains visual outputs of the video analysis.
+- **Review the "plot-poses" subfolder:** This contains visual outputs of the video analysis.

diff --git a/docs/benchmark.md b/docs/benchmark.md
index 114568decd..612e307d0f 100644
--- a/docs/benchmark.md
+++ b/docs/benchmark.md
@@ -32,4 +32,4 @@ benchmarks. For an example of how to implement a benchmark submission, refer to
.. automodule:: deeplabcut.benchmark.metrics
:members:
:show-inheritance:
-```
\ No newline at end of file
+```
diff --git a/docs/convert_maDLC.md b/docs/convert_maDLC.md
index 6671afb1ef..19dd017692 100644
--- a/docs/convert_maDLC.md
+++ b/docs/convert_maDLC.md
@@ -1,10 +1,12 @@
(convert-maDLC)=
-# How to convert a pre-2.2 project for use with DeepLabCut 2.2
+# How to convert a pre-2.2 project for use with DeepLabCut 2.2 or later
-If you have a pre-2.2 project (`labeled-data`) with a **single animal** that you want to use with a multianimal project in DLC 2.2, i.e. use your older data to now train the new multi-task deep neural network, here is what you need to do.
+If you have a pre-2.2 project (`labeled-data`) with a **single animal** that you want to use with a multianimal project
+in DLC 2.2 or later, i.e. use your older data to now train the new multi-task deep neural network, here is what you
+need to do.
(1) We recommend you make a back-up of your project folder.
@@ -14,7 +16,8 @@ If you have a pre-2.2 project (`labeled-data`) with a **single animal** that you
-- After `task, scorer, date, project_path` please add the following (i.e. in the image above, you would start adding below line 6) Note, the ordering isn't important but useful to keep consistent with the template:
+- After `task, scorer, date, project_path` please add the following (i.e. in the image above, you would start adding
+below line 6) Note, the ordering isn't important but useful to keep consistent with the template:
```python
multianimalproject: true
@@ -29,9 +32,12 @@ individuals:
- mouse1
```
-- `"uniquebodyparts: []` can stay blank, unless you have other items labeled you want to estimate (consider these as similar to bodyparts in pre-2.2); i.e. corners of a box, etc. All unique bodyparts should not be connected to the multianimal bodyparts in the skeleton you will eventually make. But see "advanced option" below.
+- `"uniquebodyparts: []` can stay blank, unless you have other items labeled you want to estimate (consider these as
+similar to bodyparts in pre-2.2); i.e. corners of a box, etc. All unique bodyparts should not be connected to the
+multianimal bodyparts in the skeleton you will eventually make. See "advanced option" below.
-- Please move your "bodyparts:" to "multianimalbodyparts:" (bodypart names must stay the same!) These are the parts that will always be interconnected fully!
+- Please move your "bodyparts:" to "multianimalbodyparts:" (bodypart names must stay the same!) These are the parts
+that will always be interconnected fully!
```python
multianimalbodyparts:
- snout
@@ -46,20 +52,25 @@ then you can set `bodyparts: MULTI!`
deeplabcut.convert2_maDLC(path_config_file, userfeedback=True)
```
-Now you will see that your data within `labeled-data` are converted to a new format, and the single animal format was saved for you under a new file named `CollectedData_ ...singleanimal.h5` and `.csv` as a back-up!
+Now you will see that your data within `labeled-data` are converted to a new format, and the single animal format was
+saved for you under a new file named `CollectedData_ ...singleanimal.h5` and `.csv` as a back-up!
-(4) We strongly recommend to first run check_labels and verify that the conversion was as expected before creating a multianimal training dataset. For instance, you can load this project `config.yaml` in the Project Manager GUI and check labels then create a multi-animal training set with
+(4) We strongly recommend to first run check_labels and verify that the conversion was as expected before creating a
+multianimal training dataset. For instance, you can load this project `config.yaml` in the Project Manager GUI and
+check labels then create a multi-animal training set with
```python
deeplabcut.create_multianimaltraining_dataset(path_config_file)
```
to begin training.
-**Advanced option:** You can also assign former `bodyparts` to either `uniquebodyparts` or `multianimalbodyparts` (you can even leave some unassigned, which means they will be dropped in the conversion).
+**Advanced option:** You can also assign former `bodyparts` to either `uniquebodyparts` or `multianimalbodyparts`
+(you can even leave some unassigned, which means they will be dropped in the conversion).
Example: Imagine you had a project with the moon and a rocket with two parts labeled:
`bodyparts: [moon, rocket_tip,rocket_bottom]`
-Now you want to use this former project (labeled-data) and work on a new dataset (videos) with one moon but multiple (3) rockets. Then convert it as follows:
+Now you want to use this former project (labeled-data) and work on a new dataset (videos) with one moon but multiple
+(3) rockets. Then convert it as follows:
```
individuals: [rocket1, rocket2, rocket3]
uniquebodyparts: [moon]
diff --git a/docs/course.md b/docs/course.md
index 8d3da73610..5a254b9fa7 100644
--- a/docs/course.md
+++ b/docs/course.md
@@ -1,5 +1,10 @@
## DeepLabCut Self-paced Course
+::::{warning}
+This course was designed for DLC 2.
+An updated version for DLC 3 is in the works.
+::::
+
Do you have video of animal behaviors? Step 1: Get Poses ...
@@ -17,7 +22,7 @@ We expect it to take *roughly* 1-2 weeks to get through if you do it rigorously.
## Installation:
You need Python and DeepLabCut installed!
-- [See these "beginner docs" for help!](https://deeplabcut.github.io/DeepLabCut/docs/beginners-guide.html)
+- [See these "beginner docs" for help!](beginners-guide)
- **WATCH:** overview of conda: [Python Tutorial: Anaconda - Installation and Using Conda](https://www.youtube.com/watch?v=YJC6ldI3hWk)
@@ -32,10 +37,8 @@ You need Python and DeepLabCut installed!
- **Learning:** learning and teaching signal processing, and overview from Prof. Demba Ba [talk at JupyterCon](https://www.youtube.com/watch?v=ywz-LLYwkQQ)
-- **Learning:** Watch a talk from Alexander Mathis (a lead DeepLabCut developer) [talk about DeepLabCut!](https://www.youtube.com/watch?v=ZjWPHM0sL4E)
-
- **DEMO:** Can I DEMO DEEPLABCUT (DLC) quickly?
- - Yes: [you can click through this DEMO notebook](https://github.com/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb)
+ - Yes: [you can click through this DEMO notebook](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb)
- AND follow along with me: [Video Tutorial!](https://www.youtube.com/watch?v=DRT-Cq2vdWs)
@@ -55,9 +58,9 @@ You need Python and DeepLabCut installed!
**What you need:** any videos where you can see the animals/objects, etc.
You can use our demo videos, grab some from the internet, or use whatever older data you have. Any camera, color/monochrome, etc will work. Find diverse videos, and label what you want to track well :)
-- IF YOU ARE PART OF THE COURSE: you will be contributing to the DLC Model Zoo :smile:
+- IF YOU ARE PART OF THE COURSE: you will be contributing to the DLC Model Zoo 😊
- - **Slides:** [Overview of starting new projects](https://github.com/DeepLabCut/DeepLabCut-Workshop-Materials/blob/master/part1-labeling.pdf)
+ - **Slides:** [Overview of starting new projects](https://github.com/DeepLabCut/DeepLabCut-Workshop-Materials/blob/main/part1-labeling.pdf)
- **READ ME PLEASE:** [DeepLabCut, the science](https://rdcu.be/4Rep)
- **READ ME PLEASE:** [DeepLabCut, the user guide](https://rdcu.be/bHpHN)
- **WATCH:** Video tutorial 1: [using the Project Manager GUI](https://www.youtube.com/watch?v=KcXogR-p5Ak)
@@ -71,12 +74,12 @@ You can use our demo videos, grab some from the internet, or use whatever older
### **Module 2: Neural Networks**
- - **Slides:** [Overview of creating training and test data, and training networks](https://github.com/DeepLabCut/DeepLabCut-Workshop-Materials/blob/master/part2-network.pdf)
+ - **Slides:** [Overview of creating training and test data, and training networks](https://github.com/DeepLabCut/DeepLabCut-Workshop-Materials/blob/main/part2-network.pdf)
- **READ ME PLEASE:** [What are convolutional neural networks?](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53)
- **READ ME PLEASE:** Here is a new paper from us describing challenges in robust pose estimation, why PRE-TRAINING really matters - which was our major scientific contribution to low-data input pose-estimation - and it describes new networks that are available to you. [Pretraining boosts out-of-domain robustness for pose estimation](https://paperswithcode.com/paper/pretraining-boosts-out-of-domain-robustness)
- - **MORE DETAILS:** ImageNet: check out the original paper and dataset: http://www.image-net.org/ (link to [ppt from Dr. Fei-Fei Li](http://www.image-net.org/papers/ImageNet_2010.ppt))
+ - **MORE DETAILS:** ImageNet: check out the original paper and dataset: http://www.image-net.org/
- **REVIEW PAPER:** [A Primer on Motion Capture with Deep Learning: Principles, Pitfalls and Perspectives](https://www.sciencedirect.com/science/article/pii/S0896627320307170)
@@ -84,11 +87,11 @@ You can use our demo videos, grab some from the internet, or use whatever older
Before you create a training/test set, please read/watch:
- - **More information:** [Which types neural networks are available, and what should I use?](https://github.com/AlexEMG/DeepLabCut/wiki/What-neural-network-should-I-use%3F)
+ - **More information:** [Which types neural networks are available, and what should I use?](https://github.com/DeepLabCut/DeepLabCut/wiki/What-neural-network-should-I-use%3F-(Trade-offs,-speed-performance,-and-considerations))
- **WATCH:** Video tutorial 1: [How to test different networks in a controlled way](https://www.youtube.com/watch?v=WXCVr6xAcCA)
- Now, decide what model(s) you want to test.
- IF you want to train on your CPU, then run the step `create_training_dataset`, in the GUI etc. on your own computer.
- - IF you want to use GPUs on google colab, [**(1)** watch this FIRST/follow along here!](https://www.youtube.com/watch?v=qJGs8nxx80A) **(2)** move your whole project folder to Google Drive, and then [**use this notebook**](https://github.com/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb)
+ - IF you want to use GPUs on google colab, [**(1)** watch this FIRST/follow along here!](https://www.youtube.com/watch?v=qJGs8nxx80A) **(2)** move your whole project folder to Google Drive, and then [**use this notebook**](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb)
**MODULE 2 webinar**: https://youtu.be/ILsuC4icBU0
diff --git a/docs/deeplabcutlive.md b/docs/deeplabcutlive.md
index f8784b651d..f1edbff3cc 100644
--- a/docs/deeplabcutlive.md
+++ b/docs/deeplabcutlive.md
@@ -1,3 +1,4 @@
+(deeplabcut-live)=
# DeepLabCut-Live!
We provide two additional pip packages that allow you to record and stream camera data and run DeeplabCut models in real-time.
diff --git a/docs/gui/PROJECT_GUI.md b/docs/gui/PROJECT_GUI.md
index 106e7eca5c..e0883c324e 100644
--- a/docs/gui/PROJECT_GUI.md
+++ b/docs/gui/PROJECT_GUI.md
@@ -10,7 +10,7 @@ As some users may be more comfortable working with an interactive interface, we
(1) Install DeepLabCut using the simple-install with Anaconda found [here!](how-to-install)*.
Now you have DeepLabCut installed, but if you want to update it, either follow the prompt in the GUI which will ask you to upgrade when a new version is available, or just go into your env (activate DEEPLABCUT) then run:
-` pip install 'deeplabcut[gui,modelzoo]'` *but please see [full install guide](https://deeplabcut.github.io/DeepLabCut/docs/installation.html)!
+` pip install 'deeplabcut[gui,modelzoo]'` *but please see [full install guide](how-to-install)!
(2) Open the terminal and run: `python -m deeplabcut`
@@ -23,15 +23,14 @@ Now you have DeepLabCut installed, but if you want to update it, either follow t
Start at the Project Management Tab and work your way through the tabs to built your customized model and deploy it on new data.
We recommend to keep the terminal visible (as well as the GUI) so you can see the ongoing processes as you step through your project, or any errors that might arise.
-- For specific napari-based labeling features, see the ["napari gui" docs](https://deeplabcut.github.io/DeepLabCut/docs/napari_GUI.html#usage).
+- For specific napari-based labeling features, see the ["napari gui" docs](napari-gui-usage).
- To change from dark to light mode, set appearance at the top:
-
-## VIDEO DEMOS: How to launch and run the Project Manager GUI:
+## Video Demos: How to launch and run the Project Manager GUI:
**Click on the images!**
@@ -39,17 +38,17 @@ Note that currently the video demo is the wxPython version, but the logic is the
[](https://youtu.be/KcXogR-p5Ak)
-### Using the Project Manager GUI with the latest DLC code (single animals, plus objects): :arrow_down:
+### Using the Project Manager GUI with the latest DLC code (single animals, plus objects): ⬇️
[](https://www.youtube.com/watch?v=JDsa8R5J0nQ)
-[READ MORE HERE](important-info-regd-usage)
+[Read more here](important-info-regd-usage)
### Using the Project Manager GUI with the latest DLC code (multiple identical-looking animals, plus objects):
[](https://www.youtube.com/watch?v=Kp-stcTm77g)
-[READ MORE HERE](important-info-regd-usage)
+[Read more here](important-info-regd-usage)
## VIDEO DEMO: How to benchmark your data with the new networks and data augmentation pipelines:
diff --git a/docs/gui/napari_GUI.md b/docs/gui/napari_GUI.md
index 5fe4fb84c5..4875b23237 100644
--- a/docs/gui/napari_GUI.md
+++ b/docs/gui/napari_GUI.md
@@ -1,3 +1,4 @@
+(napari-gui)=
# napari labeling GUI
We replaced wxPython with PySide6 + as of version 2.3. Here is how to use the napari-aspects of the new GUI. It is available in napari-hub as a stand alone GUI as well as integrated into our main GUI, [please see docs here](https://deeplabcut.github.io/DeepLabCut/docs/PROJECT_GUI.html).
@@ -30,6 +31,7 @@ To install latest development version:
` pip install git+https://github.com/DeepLabCut/napari-deeplabcut.git `
+(napari-gui-usage)=
## Usage
To use the full GUI, please run:
@@ -46,8 +48,7 @@ All accepted files (`config.yaml`, images, `.h5` data files) can be loaded eithe
The easiest way to get started is to drop a folder (typically a folder from within a DeepLabCut's `labeled-data` directory), and, if labeling from scratch, drop the corresponding `config.yaml` to automatically add a `Points layer` and populate the dropdown menus.
-[🎥 DEMO
-](https://youtu.be/hsA9IB5r73E)
+[🎥 DEMO](https://youtu.be/hsA9IB5r73E)
**Tools & shortcuts are:**
diff --git a/docs/images/box1-multi.png b/docs/images/box1-multi.png
new file mode 100644
index 0000000000..2da7320667
Binary files /dev/null and b/docs/images/box1-multi.png differ
diff --git a/docs/images/box1-single.png b/docs/images/box1-single.png
new file mode 100644
index 0000000000..c5c802d486
Binary files /dev/null and b/docs/images/box1-single.png differ
diff --git a/docs/installation.md b/docs/installation.md
index 6e9e2b0f26..42c592a6f7 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -6,6 +6,41 @@
- We recommend for most users to use our supplied CONDA environment.
- Please note, you will get the best performance with using a GPU! Please see the section on [GPU support](https://deeplabcut.github.io/DeepLabCut/docs/installation.html#gpu-support) to install your GPU driver and CUDA.
+````{admonition} Familiar with python packages and conda? Quick install here.
+:class: dropdown
+
+This assumes you have `conda`/`mamba` installed and installs DeepLabCut in a fresh
+environment. If you have an NVIDIA GPU, install PyTorch according to [their instructions
+](https://pytorch.org/get-started/locally/) (with your desired CUDA version) - you just
+need your GPU drivers installed.
+
+```bash
+conda create -n DEEPLABCUT python=3.10
+conda activate DEEPLABCUT
+conda install -c conda-forge pytables==3.8.0
+
+# install torch with your desired CUDA version (or CPU) - check their website
+# for the exact command
+pip install torch torchvision
+
+# install the latest version of DeepLabCut
+pip install --pre deeplabcut
+# or if you want to use the GUI
+pip install deeplabcut[gui]
+
+# ONLY IF YOU HAVE A CUDA GPU - check that PyTorch can access your GPU; this
+# should print `True`
+python -c "import torch; print(torch.cuda.is_available())"
+```
+
+Why do we install [pytables](https://www.pytables.org/usersguide/installation.html) with
+`conda` and not `pip`? Because it requires some libraries that not all users will have
+installed, and conda will ensure that they are installed as well.
+
+If you're familiar with the command line and want TensorFlow support, look [below](
+deeplabcut-with-tf-install) for a fresh installation that has worked for us (on Linux)
+and makes it possible to use the GPU with both PyTorch and TensorFlow.
+````
## CONDA: The installation process is as easy as this figure! -->
@@ -24,15 +59,15 @@
- **Apple M-chip GPU?** Be sure to install miniconda3, and your GPU will be used by default.
````
-
### Step 1: Install Python via Anaconda
-#### Install [anaconda](https://www.anaconda.com/distribution/), or use miniconda3 for MacOS users (see below)
+#### Install [anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html#), or use miniconda3 for MacOS users (see below)
- Anaconda is an easy way to install Python and additional packages across various operating systems. With Anaconda you create all the dependencies in an [environment](https://conda.io/docs/user-guide/tasks/manage-environments.html) on your machine.
```{Hint}
-Download anaconda for your operating system: https://www.anaconda.com/distribution/.
+Download anaconda for your operating system: [anaconda.com/download/
+](https://www.anaconda.com/download/)
```
- IF you use a M1 or M2 chip in your MacBook with v12.5+ (typically 2020 or newer machines), we recommend **miniconda3,** which operates with the same principles as anaconda. This is straight forward and explained in detail here: https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html. But in short, open the program "terminal" and copy/paste and run the code that is supplied below.
@@ -83,11 +118,46 @@ Be sure you are in the folder that has the `.yaml` file, then run:
Now you should see (`nameofenv`) on the left of your terminal screen, i.e. ``(DEEPLABCUT) YourName-MacBook...``
NOTE: no need to run pip install deeplabcut, as it is already installed!!! :)
+(deeplabcut-with-tf-install)=
#### 💡 Notice: PyTorch and TensorFlow Support within DeepLabCut
-```{Hint}
+
+````{admonition} DeepLabCut TensorFlow Support
:class: dropdown
-As of June 2024 we have a PyTorch Engine backend and we will be depreciating the TensorFlow backend by the end of 2024. Currently, if you want to use TensorFlow, you need to run `pip install deeplabcut[tf]` in order to install the correct version of TensorFlow in your conda env. Please note, we will be providing bug fixes, but we will not be supporting new TensorFlow versions beyond 2.10 (Windows), and 2.12 for other OS.
+As of June 2024 we have a PyTorch Engine backend and we will be depreciating the
+TensorFlow backend by the end of 2024. Currently, if you want to use TensorFlow, you
+need to run `pip install deeplabcut[tf]` in order to install the correct version of
+TensorFlow in your conda env. Please note, we will be providing bug fixes, but we will
+not be supporting new TensorFlow versions beyond 2.10 (Windows), and 2.12 for other OS.
+
+Installing TensorFlow and getting it to have access to the GPU can be a bit tricky.
+Check TensorFlow's [compatibility matrix](https://www.tensorflow.org/install/source#gpu)
+to know which version of CUDA and cuDNN you should install.
+
+We have found that installing DeepLabCut with the following commands works well for
+Linux users to install PyTorch 2.3.1, TensorFlow 2.12, CUDA 11.8 and cuDNN 8 in a Conda
+environment:
+
+```bash
+conda create -n deeplabcut-with-tf "python=3.10"
+conda activate deeplabcut-with-tf
+
+# Install the desired TensorFlow version, built for CUDA 11.8 and cuDNN 8
+pip install "tensorflow==2.12" "tensorpack>=0.11" "tf_slim>=1.1.0"
+
+# Install PyTorch with a version using CUDA 11.8 and cuDNN 8
+pip install "torch==2.3.1" torchvision --index-url https://download.pytorch.org/whl/cu118
+
+# Create symbolic links to NVIDIA shared libraries for TensorFlow
+# -> as described in their installation docs:
+# https://www.tensorflow.org/install/pip#step-by-step_instructions
+
+pushd $(dirname $(python -c 'print(__import__("tensorflow").__file__)'))
+ln -svf ../nvidia/*/lib/*.so* .
+popd
+
+pip install --pre deeplabcut
```
+````
**Great, that's it! DeepLabCut is installed!** 🎉💜
@@ -109,7 +179,7 @@ To git clone type: ``git clone https://github.com/DeepLabCut/DeepLabCut.git``).
### PIP:
-- Everything you need to build custom models within DeepLabCut (i.e., use our source code and our dependencies) can be installed with `pip install 'deeplabcut[gui]'` (for GUI support w/tensorflow) or without the gui: `pip install 'deeplabcut'`.
+- Everything you need to build custom models within DeepLabCut (i.e., use our source code and our dependencies) can be installed with `pip install 'deeplabcut[gui]'` (for GUI support w/PyTorch) or without the gui: `pip install 'deeplabcut'`.
- If you want to use the SuperAnimal models, then please use `pip install 'deeplabcut[gui,modelzoo]'`.
## DOCKER:
@@ -118,13 +188,29 @@ To git clone type: ``git clone https://github.com/DeepLabCut/DeepLabCut.git``).
## Pro Tips:
-More [installation ProTips](installTips) are also available.
+More [installation ProTips](installation-tips) are also available.
-If you ever want to update your DLC, just run `pip install --upgrade deeplabcut` once you are inside your env. If you want to use a specific release, then you need to specify the version you want, such as `pip install deeplabcut==2.2`. Once installed, you can check the version by running `import deeplabcut` `deeplabcut.__version__`. Don't be afraid to update, DLC is backwards compatible with your 2.0+ projects and performance continues to get better and new features are added nearly monthly.
+If you ever want to update your DLC, just run `pip install --upgrade deeplabcut` once
+you are inside your env. If you want to use a specific release, then you need to specify
+the version you want, such as `pip install deeplabcut==3.0`. Once installed, you can
+check the version by running `import deeplabcut` `deeplabcut.__version__`. Don't be
+afraid to update, DLC is backwards compatible with your 2.0+ projects and performance
+continues to get better and new features are added nearly monthly.
-Here are some conda environment management tips: https://kapeli.com/cheat_sheets/Conda.docset/Contents/Resources/Documents/index
+**All of the data you labelled in version 2.X is also compatible with version 3+ and the
+PyTorch engine**! There is no change in the workflow or the way labels are handled: the
+big changes happen under-the-hood! If you've been working with DeepLabCut 2.X and want
+to learn more about moving to the PyTorch engine, checkout our docs on [moving from
+TensorFlow to PyTorch](dlc3-user-guide)
-**Pro Tip:** If you want to modify code and then test it, you can use our provided testscripts. This would mean you need to be up-to-date with the latest GitHub-based code though! Please see [here](installTips) on how to get the latest GitHub code, and how to test your installation by following this video: https://www.youtube.com/watch?v=IOWtKn3l33s.
+Here are some conda environment management tips: [kapeli.com: Conda Cheat Sheet](
+https://kapeli.com/cheat_sheets/Conda.docset/Contents/Resources/Documents/index)
+
+**Pro Tip:** If you want to modify code and then test it, you can use our provided
+testscripts. This would mean you need to be up-to-date with the latest GitHub-based code
+though! Please see [here](installation-tips) on how to get the latest GitHub code, and
+how to test your installation by following this video:
+https://www.youtube.com/watch?v=IOWtKn3l33s.
## Creating your own customized conda env (recommended route for Linux: Ubuntu, CentOS, Mint, etc.)
@@ -136,7 +222,9 @@ In the terminal type:
`conda create -n DLC python=3.10`
-**Current version:** The only thing you then need to add to the env is deeplabcut (`pip install deeplabcut`) or `pip install 'deeplabcut[gui]'` which has a napari based GUI.
+**Current version:** The only thing you then need to add to the env is deeplabcut (
+`pip install deeplabcut`) or `pip install 'deeplabcut[gui]'` which has a napari based
+GUI.
## **GPU Support:**
@@ -145,36 +233,51 @@ The ONLY thing you need to do **first** if you have an NVIDIA GPU and the matchi
- CUDA: https://developer.nvidia.com/cuda-downloads (just follow the prompts here!)
- DRIVERS: https://www.nvidia.com/Download/index.aspx
-#### The most common "new user" hurdle is installing and using your GPU, so don't get discouraged!
+### The most common "new user" hurdle is installing and using your GPU, so don't get discouraged!
-**CRITICAL:** If you have a GPU, you should FIRST **install the NVIDIA CUDA package and an appropriate driver for your specific GPU**, then you can use the supplied conda file. Please follow the instructions found here https://www.tensorflow.org/install/gpu, and more tips below, to install the correct version of CUDA and your graphic card driver. The order of operations matters.
+**CRITICAL:** If you have a GPU, you should FIRST **install an appropriate driver for
+your specific GPU**, then you can use the supplied conda file. You'll need an NVIDIA GPU
+which is compatible with CUDA. To see a list of CUDA-enabled NVIDIA GPUs, please [see
+their website](https://developer.nvidia.com/cuda-gpus).
-- Here we provide notes on how to install and check your GPU use with TensorFlow (which is used by DeepLabCut and already installed with the Anaconda files above). Thus, you do not need to independently install tensorflow.
+- Here we provide notes on how to install and check your GPU use with TensorFlow (which
+is used by DeepLabCut and already installed with the Anaconda files above). Thus, you do
+not need to independently install tensorflow.
+**FIRST**, install a driver for your GPU. Find DRIVER HERE:
+https://www.nvidia.com/download/index.aspx
-**FIRST**, install a driver for your GPU. Find DRIVER HERE: https://www.nvidia.com/download/index.aspx
-- check which driver is installed by typing this into the terminal: ``nvidia-smi``.
+- Check which driver is installed by typing this into the terminal: ``nvidia-smi``.
**SECOND**, install CUDA: https://developer.nvidia.com/ (Note that cuDNN, https://developer.nvidia.com/cudnn, is supplied inside the anaconda environment files, so you don't need to install it again).
**THIRD:** Follow the steps above to get the `DEEPLABCUT` conda file and install it!
-##### Notes:
-
- - **As of version 3.0+ we moved to PyTorch. The Last supported version of TensorFlow is 2.10 (window users) and 2.12 for others (we have not tested beyond this).**
- - Please be mindful different versions of TensorFlow require different CUDA versions.
- - As the combination of TensorFlow and CUDA matters, we strongly encourage you to **check your driver/cuDNN/CUDA/TensorFlow versions** [on this StackOverflow post](https://stackoverflow.com/questions/30820513/what-is-version-of-cuda-for-nvidia-304-125/30820690#30820690).
- - To check your GPU is working, in the terminal, run:
-
- `nvcc -V` to check your installed version(s).
-
-- The best practice is to then run the supplied `testscript.py` (this is inside the examples folder you acquired when you git cloned the repo). Here is more information/a short [video on running the testscript](https://www.youtube.com/watch?v=IOWtKn3l33s).
-
-- Additionally, if you want to use the bleeding edge, with yout git clone you also get the latest code. While inside the main DeepLabCut folder, you can run `./reinstall.sh` to be sure it's installed (more here: https://github.com/DeepLabCut/DeepLabCut/wiki/How-to-use-the-latest-GitHub-code)
-
-- You can test that your GPU is being properly engaged with these additional [tips](https://www.tensorflow.org/programmers_guide/using_gpu).
-
-- Ubuntu users might find this [installation guide](https://deeplabcut.github.io/DeepLabCut/docs/recipes/installTips.html#installation-on-ubuntu-20-04-lts) for a fresh ubuntu install useful as well.
+#### Notes:
+
+- **As of version 3.0+ we moved to PyTorch. The Last supported version of TensorFlow is
+2.10 (window users) and 2.12 for others (we have not tested beyond this).**
+- Please be mindful different versions of TensorFlow require different CUDA versions.
+- As the combination of TensorFlow and CUDA matters, we strongly encourage you to
+**check your driver/cuDNN/CUDA/TensorFlow versions** [on this StackOverflow post](
+https://stackoverflow.com/questions/30820513/what-is-version-of-cuda-for-nvidia-304-125/30820690#30820690
+).
+- To check your GPU is working, in the terminal, run:
+
+`nvcc -V` to check your installed version(s).
+
+- The best practice is to then run the supplied `testscript_pytorch_single_animal.py`
+(or `testscript.py` for the TensorFlow engine); this is inside the examples folder you
+acquired when you git cloned the repo. Here is more information/a short
+[video on running the testscript](https://www.youtube.com/watch?v=IOWtKn3l33s).
+- Additionally, if you want to use the bleeding edge, with yout git clone you also get
+the latest code. While inside the main DeepLabCut folder, you can run `./reinstall.sh`
+to be sure it's installed (more [here](installation-tips))
+- You can test that your GPU is being properly engaged with these additional [tips](
+https://www.tensorflow.org/programmers_guide/using_gpu).
+- Ubuntu users might find this [installation guide](
+https://deeplabcut.github.io/DeepLabCut/docs/recipes/installTips.html#installation-on-ubuntu-20-04-lts
+) for a fresh ubuntu install useful as well.
## Troubleshooting:
@@ -230,7 +333,3 @@ If you perform the system-wide installation, and the computer has other Python p
- If you want to use a pre3.0 version, you will need [TensorFlow](https://www.tensorflow.org/) (we used version 1.0 in the Nature Neuroscience paper, later versions also work with the provided code (we tested **TensorFlow versions 1.0 to 1.15, and 2.0 to 2.10**; we recommend TF2.10 now) for Python 3.8, 3.9, 3.10 with GPU support.
- To note, is it possible to run DeepLabCut on your CPU, but it will be VERY slow (see: [Mathis & Warren](https://www.biorxiv.org/content/early/2018/10/30/457242)). However, this is the preferred path if you want to test DeepLabCut on your own computer/data before purchasing a GPU, with the added benefit of a straightforward installation! Otherwise, use our COLAB notebooks for GPU access for testing.
- Docker: We highly recommend advanced users use the supplied [Docker container](docker-containers)
-
-
-
-Return to [readme](readme).
diff --git a/docs/maDLC_UserGuide.md b/docs/maDLC_UserGuide.md
index 8f6ca7ab39..8f5b5a964e 100644
--- a/docs/maDLC_UserGuide.md
+++ b/docs/maDLC_UserGuide.md
@@ -4,9 +4,16 @@
This document should serve as the user guide for maDLC,
and it is here to support the scientific advances presented in [Lauer et al. 2022](https://doi.org/10.1038/s41592-022-01443-0).
-
Note, we strongly encourage you to use the [Project Manager GUI](project-manager-gui) when you first start using multi-animal mode. Each tab is customized for multi-animal when you create or load a multi-animal project. As long as you follow the recommendations within the GUI, you should be good to go!
+````{versionadded} 3.0.0
+PyTorch is now available as a deep learning engine for pose estimation models, along
+with new model architectures! For more information about moving from TensorFlow to
+PyTorch (if you're already familiar with DeepLabCut & the TensorFlow engine),
+check out [the PyTorch user guide](dlc3-user-guide). If you're just starting
+out with DeepLabCut, we suggest you use the PyTorch backend.
+````
+
## How to think about using maDLC:
You should think of maDLC being **four** parts.
@@ -21,9 +28,9 @@ Thus, you should always label, train, and evaluate the pose estimation performan
## Install:
-**Quick start:** If you are using DeepLabCut on the cloud, or otherwise cannot use the GUIs and you should install with: `pip install 'deeplabcut'`; if you need GUI support, please use: `pip install 'deeplabcut[gui]'`.
+**Quick start:** If you are using DeepLabCut on the cloud, or otherwise cannot use the GUIs and you should install with: `pip install 'deeplabcut'`; if you need GUI support, please use: `pip install 'deeplabcut[gui]'`. Check the [installation page](how-to-install) for more information, including GPU support.
-IF you want to use the bleeding edge version to make edits to the code, see here on how to install it and test it (https://deeplabcut.github.io/DeepLabCut/docs/recipes/installTips.html#how-to-use-the-latest-updates-directly-from-github).
+IF you want to use the bleeding edge version to make edits to the code, see [here on how to install it and test it](https://deeplabcut.github.io/DeepLabCut/docs/recipes/installTips.html#how-to-use-the-latest-updates-directly-from-github).
## Get started in the terminal or Project GUI:
@@ -35,7 +42,7 @@ Then follow the tabs! It might be useful to read the following, however, so you
```{Hint}
🚨 If you use Windows, please always open the terminal with administrator privileges! Right click, and "run as administrator".
```
- Please read more [here](https://github.com/DeepLabCut/Docker4DeepLabCut2.0), and in our Nature Protocols paper [here](https://www.nature.com/articles/s41596-019-0176-0). And, see our [troubleshooting wiki](https://github.com/DeepLabCut/DeepLabCut/wiki/Troubleshooting-Tips).
+ Please read more [here](https://deeplabcut.github.io/DeepLabCut/docs/docker.html), and in our Nature Protocols paper [here](https://www.nature.com/articles/s41596-019-0176-0). And, see our [troubleshooting wiki](https://github.com/DeepLabCut/DeepLabCut/wiki/Troubleshooting-Tips).
Open an ``ipython`` session and import the package by typing in the terminal:
```python
@@ -43,20 +50,25 @@ ipython
import deeplabcut
```
-```{TIP:}
+```{TIP}
for every function there is a associated help document that can be viewed by adding a **?** after the function name; i.e. ``deeplabcut.create_new_project?``. To exit this help screen, type ``:q``.
```
-### Create a New Project:
+### (A) Create a New Project
```python
-deeplabcut.create_new_project('ProjectName','YourName', ['/usr/FullPath/OfVideo1.avi', '/usr/FullPath/OfVideo2.avi', '/usr/FullPath/OfVideo1.avi'],
- copy_videos=True, multianimal=True)
+deeplabcut.create_new_project(
+ "ProjectName",
+ "YourName",
+ ["/usr/FullPath/OfVideo1.avi", "/usr/FullPath/OfVideo2.avi", "/usr/FullPath/OfVideo1.avi"],
+ copy_videos=True,
+ multianimal=True,
+)
```
-Tip: if you want to place the project folder somewhere specific, please also pass : ``working_directory = 'FullPathOftheworkingDirectory'``
+Tip: if you want to place the project folder somewhere specific, please also pass : ``working_directory = "FullPathOftheworkingDirectory"``
-- Note, if you are a linux/macos user the path should look like: ``['/home/username/yourFolder/video1.mp4']``; if you are a Windows user, it should look like: ``[r'C:\username\yourFolder\video1.mp4']``
+- Note, if you are a linux/macOS user the path should look like: ``["/home/username/yourFolder/video1.mp4"]``; if you are a Windows user, it should look like: ``[r"C:\username\yourFolder\video1.mp4"]``
- Note, you can also put ``config_path = `` in front of the above line to create the path to the config.yaml that is used in the next step, i.e. ``config_path=deeplabcut.create_project(...)``)
- If you do not, we recommend setting a variable so this can be easily used! Once you run this step, the config_path is printed for you once you run this line, so set a variable for ease of use, i.e. something like:
```python
@@ -64,9 +76,20 @@ config_path = '/thefulloutputpath/config.yaml'
```
- just be mindful of the formatting for Windows vs. Unix, see above.
-This set of arguments will create a project directory with the name **Name of the project+name of the experimenter+date of creation of the project** in the **Working directory** and creates the symbolic links to videos in the **videos** directory. The project directory will have subdirectories: **dlc-models**, **labeled-data**, **training-datasets**, and **videos**. All the outputs generated during the course of a project will be stored in one of these subdirectories, thus allowing each project to be curated in separation from other projects. The purpose of the subdirectories is as follows:
-
-**dlc-models:** This directory contains the subdirectories *test* and *train*, each of which holds the meta information with regard to the parameters of the feature detectors in configuration files. The configuration files are YAML files, a common human-readable data serialization language. These files can be opened and edited with standard text editors. The subdirectory *train* will store checkpoints (called snapshots in TensorFlow) during training of the model. These snapshots allow the user to reload the trained model without re-training it, or to pick-up training from a particular saved checkpoint, in case the training was interrupted.
+This set of arguments will create a project directory with the name **Name of the project+name of the experimenter+date of creation of the project** in the **Working directory** and creates the symbolic links to videos in the **videos** directory. The project directory will have subdirectories: **dlc-models**, **dlc-models-pytorch**, **labeled-data**, **training-datasets**, and **videos**. All the outputs generated during the course of a project will be stored in one of these subdirectories, thus allowing each project to be curated in separation from other projects. The purpose of the subdirectories is as follows:
+
+**dlc-models** and **dlc-models-pytorch** have a similar structure: the first contains
+files for the TensorFlow engine while the second contains files for the PyTorch engine.
+At the top level in these directories, there are
+directories referring to different iterations of labels refinement (see below): **iteration-0**, **iteration-1**, etc.
+The refinement iterations directories store shuffle directories, each shuffle directory stores model data related to a
+particular experiment: trained and tested on a particular training and testing sets, and with a particular model
+architecture. Each shuffle directory contains the subdirectories *test* and *train*, each of which holds the meta
+information with regard to the parameters of the feature detectors in configuration files. The configuration files are
+YAML files, a common human-readable data serialization language. These files can be opened and edited with standard text
+editors. The subdirectory *train* will store checkpoints (called snapshots) during training of the model. These
+snapshots allow the user to reload the trained model without re-training it, or to pick-up training from a particular
+saved checkpoint, in case the training was interrupted.
**labeled-data:** This directory will store the frames used to create the training dataset. Frames from different videos are stored in separate subdirectories. Each frame has a filename related to the temporal index within the corresponding video, which allows the user to trace every frame back to its origin.
@@ -75,33 +98,55 @@ This set of arguments will create a project directory with the name **Name of th
**videos:** Directory of video links or videos. When **copy\_videos** is set to ``False``, this directory contains symbolic links to the videos. If it is set to ``True`` then the videos will be copied to this directory. The default is ``False``. Additionally, if the user wants to add new videos to the project at any stage, the function **add\_new\_videos** can be used. This will update the list of videos in the project's configuration file. Note: you neither need to use this folder for videos, nor is it required for analyzing videos (they can be anywhere).
```python
-deeplabcut.add_new_videos('Full path of the project configuration file*', ['full path of video 4', 'full path of video 5'], copy_videos=True/False)
+deeplabcut.add_new_videos(
+ "Full path of the project configuration file*",
+ ["full path of video 4", "full path of video 5"],
+ copy_videos=True/False,
+)
```
*Please note, *Full path of the project configuration file* will be referenced as ``config_path`` throughout this protocol.
-You can also use annotated data from singe-animal projects, by converting those files. There are docs for this: [convert single to multianimal annotation data](convert-maDLC)
+You can also use annotated data from single-animal projects, by converting those files.
+There are docs for this: [convert single to multianimal annotation data](convert-maDLC)
+
+
+
+#### API Docs
+````{admonition} Click the button to see API Docs
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_new_project.rst
+```
+````
-### Configure the Project:
+### (B) Configure the Project
-- open the **config.yaml** file (in a text editor (like atom, gedit, vim etc.)), which can be found in the subfolder created when you set your project name, to change parameters and identify label names! This is a crucial step.
+Next, open the **config.yaml** file, which was created during **create\_new\_project**.
+You can edit this file in any text editor. Familiarize yourself with the meaning of the
+parameters (Box 1). You can edit various parameters, in particular you **must add the list of *individuals* and *bodyparts* (or points of interest)**.
-Next, open the **config.yaml** file, which was created during **create\_new\_project**. You can edit this file in any text editor. Familiarize yourself with the meaning of the parameters (Box 1). You can edit various parameters, in particular you **must add the list of *bodyparts* (or points of interest)** that you want to track. You can also set the *colormap* here that is used for all downstream steps (can also be edited at anytime), like labeling GUIs, videos, etc. Here any [matplotlib colormaps](https://matplotlib.org/tutorials/colors/colormaps.html) will do!
+You can also set the *colormap* here that is used for all downstream steps (can also be edited at anytime), like labeling GUIs, videos, etc. Here any [matplotlib colormaps](https://matplotlib.org/tutorials/colors/colormaps.html) will do!
An easy way to programmatically edit the config file at any time is to use the function **edit\_config**, which takes the full path of the config file to edit and a dictionary of key–value pairs to overwrite.
-````python
-edits = {'colormap': 'summer',
- 'individuals': ['mickey', 'minnie', 'bianca'],
- 'skeleton': [['snout', 'tailbase'], ['snout', 'rightear']]}
+```python
+import deeplabcut
+
+config_path = "/path/to/project-dlc-2025-01-01/config.yaml"
+edits = {
+ "colormap": "summer",
+ "individuals": ["mickey", "minnie", "bianca"],
+ "skeleton": [["snout", "tailbase"], ["snout", "rightear"]]
+}
deeplabcut.auxiliaryfunctions.edit_config(config_path, edits)
-````
+```
Please DO NOT have spaces in the names of bodyparts, uniquebodyparts, individuals, etc.
-**ATTENTIONt:** You need to edit the config.yaml file to **modify the following items** which specify the animal ID, body parts, and any unique labels. Note, we also highly recommend that you use **more bodypoints** that you might be interested in for your experiment, i.e., labeling along the spine/tail for 8 bodypoints would be better than four. This will help the performance.
+**ATTENTION:** You need to edit the config.yaml file to **modify the following items** which specify the animal ID, bodyparts, and any unique labels. Note, we also highly recommend that you use **more bodyparts** that you might be interested in for your experiment, i.e., labeling along the spine/tail for 8 bodyparts would be better than four. This will help the performance.
Modifying the `config.yaml` is crucial:
@@ -123,6 +168,7 @@ multianimalbodyparts:
identity: True/False
```
+
**Individuals:** are names of "individuals" in the annotation dataset. These should/can be generic (e.g. mouse1, mouse2, etc.). These individuals are comprised of the same bodyparts defined by `multianimalbodyparts`. For annotation in the GUI and training, it is important that all individuals in each frame are labeled. Thus, keep in mind that you need to set individuals to the maximum number in your labeled-data set, .i.e., if there is (even just one frame) with 17 animals then the list should be `- indv1` to `- indv17`. Note, once trained if you have a video with more or less animals, that is fine - you can have more or less animals during video analysis!
**Identity:** If you can tell the animals apart, i.e., one might have a collar, or a black marker on the tail of a mouse, then you should label these individuals consistently (i.e., always label the mouse with the black marker as "indv1", etc). If you have this scenario, please set `identity: True` in your `config.yaml` file. If you have 4 black mice, and you truly cannot tell them apart, then leave this as `false`.
@@ -131,14 +177,22 @@ identity: True/False
**Uniquebodyparts:** are points that you want to track, but that appear only once within each frame, i.e. they are "unique". Typically these are things like unique objects, landmarks, tools, etc. They can also be animals, e.g. in the case where one German shepherd is attending to many sheep the sheep bodyparts would be multianimalbodyparts, the shepherd parts would be uniquebodyparts and the individuals would be the list of sheep (e.g. Polly, Molly, Dolly, ...).
-### Select Frames to Label:
+### (C) Select Frames to Label
**CRITICAL:** A good training dataset should consist of a sufficient number of frames that capture the breadth of the behavior. This ideally implies to select the frames from different (behavioral) sessions, different lighting and different animals, if those vary substantially (to train an invariant, robust feature detector). Thus for creating a robust network that you can reuse in the laboratory, a good training dataset should reflect the diversity of the behavior with respect to postures, luminance conditions, background conditions, animal identities, etc. of the data that will be analyzed. For the simple lab behaviors comprising mouse reaching, open-field behavior and fly behavior, 100−200 frames gave good results [Mathis et al, 2018](https://www.nature.com/articles/s41593-018-0209-y). However, depending on the required accuracy, the nature of behavior, the video quality (e.g. motion blur, bad lighting) and the context, more or less frames might be necessary to create a good network. Ultimately, in order to scale up the analysis to large collections of videos with perhaps unexpected conditions, one can also refine the data set in an adaptive way (see refinement below). **For maDLC, be sure you have labeled frames with closely interacting animals!**
The function `extract_frames` extracts frames from all the videos in the project configuration file in order to create a training dataset. The extracted frames from all the videos are stored in a separate subdirectory named after the video file’s name under the ‘labeled-data’. This function also has various parameters that might be useful based on the user’s need.
+
```python
-deeplabcut.extract_frames(config_path, mode='automatic/manual', algo='uniform/kmeans', userfeedback=False, crop=True/False)
+deeplabcut.extract_frames(
+ config_path,
+ mode='automatic/manual',
+ algo='uniform/kmeans',
+ userfeedback=False,
+ crop=True/False,
+)
```
+
**CRITICAL POINT:** It is advisable to keep the frame size small, as large frames increase the training and
inference time, or you might not have a large enough GPU for this.
When running the function `extract_frames`, if the parameter crop=True, then you will be asked to draw a box within the GUI (and this is written to the config.yaml file).
@@ -160,63 +214,99 @@ behaviors, and not extract the frames across the whole video. This can be achiev
parameters in the config.yaml file. Also, the user can change the number of frames to extract from each video using
the numframes2extract in the config.yaml file.
- **For maDLC, be sure you have labeled frames with closely interacting animals!** Therefore, manually selecting some frames is a good idea if interactions are not highly frequent in the video.
+```{TIP}
+For maDLC, **be sure you have labeled frames with closely interacting animals**!
+Therefore, manually selecting some frames is a good idea if interactions are not highly
+frequent in the video.
+```
-However, picking frames is highly dependent on the data and the behavior being studied. Therefore, it is hard to
-provide all purpose code that extracts frames to create a good training dataset for every behavior and animal. If the user feels specific frames are lacking, they can extract hand selected frames of interest using the interactive GUI
+However, picking frames is highly dependent on the data and the behavior being studied.
+Therefore, it is hard to provide all purpose code that extracts frames to create a good
+training dataset for every behavior and animal. If the user feels specific frames are
+lacking, they can extract hand selected frames of interest using the interactive GUI
provided along with the toolbox. This can be launched by using:
+
```python
deeplabcut.extract_frames(config_path, 'manual')
```
-The user can use the *Load Video* button to load one of the videos in the project configuration file, use the scroll
-bar to navigate across the video and *Grab a Frame* (or a range of frames, as of version 2.0.5) to extract the frame(s). The user can also look at the extracted frames and e.g. delete frames (from the directory) that are too similar before reloading the set and then manually annotating them.
+// FIXME(niels) - add a napari frame extractor description.
+The user can use the *Load Video* button to load one of the videos in the project
+configuration file, use the scroll bar to navigate across the video and *Grab a Frame*.
+The user can also look at the extracted frames and e.g. delete frames (from the
+directory) that are too similar before reloading the set and then manually annotating
+them.
+
+````{admonition} Click the button to see API Docs
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.extract_frames.rst
+```
+````
-### Label Frames:
+### (D) Label Frames
```python
deeplabcut.label_frames(config_path)
```
-As of 2.2 there is a new multi-animal labeling GUI (as long as in your `config.yaml` says `multianimalproject: true` at the top, this will automatically launch).
+The toolbox provides a function **label_frames** which helps the user to easily label
+all the extracted frames using an interactive graphical user interface (GUI). The user
+should have already named the bodyparts to label (points of interest) in the
+project’s configuration file by providing a list. The following command invokes the
+napari-deeplabcut labelling GUI.
-The toolbox provides a function **label_frames** which helps the user to easily label all the extracted frames using
-an interactive graphical user interface (GUI). The user should have already named the body parts to label (points of
-interest) in the project’s configuration file by providing a list. The following command invokes the labeling toolbox.
+[🎥 DEMO](https://youtu.be/hsA9IB5r73E)
-The user needs to use the *Load Frames* button to select the directory which stores the extracted frames from one of
-the videos. Subsequently, the user can use one of the radio buttons (top right) to select a body part to label. **RIGHT** click to add the label. Left click to drag the label, if needed. If you label a part accidentally, you can use the middle button on your mouse to delete (or hit the delete key while you hover over the point)! If you cannot see a body part in the frame, skip over the label! Please see the ``HELP`` button for more user instructions. This auto-advances once you labeled the first body part. You can also advance to the next frame by clicking on the RIGHT arrow on your keyboard (and go to a previous frame with LEFT arrow).
-Each label will be plotted as a dot in a unique color.
+HOT KEYS IN THE Labeling GUI (also see "help" in GUI):
-The user is free to move around the body part and once satisfied with its position, can select another radio button
-(in the top right) to switch to the respective body part (it otherwise auto-advances). The user can skip a body part if it is not visible. Once all the visible body parts are labeled, then the user can use ‘Next Frame’ to load the following frame. The user needs to save the labels after all the frames from one of the videos are labeled by clicking the save button at the bottom right. Saving the labels will create a labeled dataset for each video in a hierarchical data file format (HDF) in the
-subdirectory corresponding to the particular video in **labeled-data**. You can save at any intermediate step (even without closing the GUI, just hit save) and you return to labeling a dataset by reloading it!
+```
+Ctrl + C: Copy labels from previous frame.
+Keyboard arrows: advance frames.
+Delete key: delete label.
+```
-**CRITICAL POINT:** It is advisable to **consistently label similar spots** (e.g., on a wrist that is very large, try
-to label the same location). In general, invisible or occluded points should not be labeled by the user, unless you want to teach the network to "guess" - this is possible, but could affect accuracy. If you don't want/or don't see a bodypart, they can simply be skipped by not applying the label anywhere on the frame.
+
-OPTIONAL: In the event of adding more labels to the existing labeled dataset, the user need to append the new
-labels to the bodyparts in the config.yaml file. Thereafter, the user can call the function **label_frames**. A box will pop up and ask the user if they wish to display all parts, or only add in the new labels. Saving the labels after all the images are labelled will append the new labels to the existing labeled dataset.
+**CRITICAL POINT:** It is advisable to **consistently label similar spots** (e.g., on a
+wrist that is very large, try to label the same location). In general, invisible or
+occluded points should not be labeled by the user, unless you want to teach the network
+to "guess" - this is possible, but could affect accuracy. If you don't want/or don't see
+a bodypart, they can simply be skipped by not applying the label anywhere on the frame.
-**maDeepLabCut CRITICAL POINT:** For multi-animal labeling, unless you can tell apart the animals, you do not need to worry about the "ID" of each animal. For example: if you have a white and black mouse label the white mouse as animal 1, and black as animal 2 across all frames. If two black mice, then the ID label 1 or 2 can switch between frames - no need for you to try to identify them (but always label consistently within a frame). If you have 2 black mice but one always has an optical fiber (for example), then DO label them consistently as animal1 and animal_fiber (for example). The point of multi-animal DLC is to train models that can first group the correct bodyparts to individuals, then associate those points in a given video to a specific individual, which then also uses temporal information to link across the video frames.
+OPTIONAL: In the event of adding more labels to the existing labeled dataset, the user
+needs to append the new labels to the bodyparts in the config.yaml file. Thereafter, the
+user can call the function **label_frames**. A box will pop up and ask the user if they
+wish to display all parts, or only add in the new labels. Saving the labels after all
+the images are labelled will append the new labels to the existing labeled dataset.
-Note, we also highly recommend that you use more bodypoints that you might otherwise have (see the example below).
+**maDeepLabCut CRITICAL POINT:** For multi-animal labeling, unless you can tell apart
+the animals, you do not need to worry about the "ID" of each animal. For example: if you
+have a white and black mouse label the white mouse as animal 1, and black as animal 2
+across all frames. If two black mice, then the ID label 1 or 2 can switch between
+frames - no need for you to try to identify them (but always label consistently within a
+frame). If you have 2 black mice but one always has an optical fiber (for example), then
+DO label them consistently as animal1 and animal_fiber (for example). The point of
+multi-animal DLC is to train models that can first group the correct bodyparts to
+individuals, then associate those points in a given video to a specific individual,
+which then also uses temporal information to link across the video frames.
-**Example Labeling with maDeepLabCut:**
-- note you should within an animal be consistent, i.e., all bodyparts on mouse1 should be on mouse1, but across frames "mouse1" can be any of the black mice (as here it is nearly impossible to tell them apart visually). IF you can tell them apart, do label consistently!
+Note, we also highly recommend that you use more bodyparts that you might otherwise have
+(see the example below).
-
-
-
+For more information, checkout the [napari-deeplabcut docs](napari-gui) for
+more information about the labelling workflow.
-### Check Annotated Frames:
+### (E) Check Annotated Frames
Checking if the labels were created and stored correctly is beneficial for training, since labeling
is one of the most critical parts for creating the training dataset. The DeepLabCut toolbox provides a function
-‘check_labels’ to do so. It is used as follows:
+`check_labels` to do so. It is used as follows:
+
```python
deeplabcut.check_labels(config_path, visualizeindividuals=True/False)
```
+
**maDeepLabCut:** you can check and plot colors per individual or per body part, just set the flag `visualizeindividuals=True/False`. Note, you can run this twice in both states to see both images.
@@ -225,13 +315,33 @@ deeplabcut.check_labels(config_path, visualizeindividuals=True/False)
For each video directory in labeled-data this function creates a subdirectory with **labeled** as a suffix. Those directories contain the frames plotted with the annotated body parts. The user can double check if the body parts are labeled correctly. If they are not correct, the user can reload the frames (i.e. `deeplabcut.label_frames`), move them around, and click save again.
+````{admonition} Click the button to see API Docs
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.check_labels.rst
+```
+````
+
+### (F) Create Training Dataset
+
+At this point, you'll need to select your neural network type.
+
+For the **PyTorch engine**, please see [the PyTorch Model Architectures](
+dlc3-architectures) for options.
-### Create Training Dataset:
+For the **TensorFlow engine**, please see Lauer et al. 2021 for options. Multi-animal
+models will use `imgaug`, ADAM optimization, our new DLCRNet, and batch training. We
+suggest keeping these defaults at this time. At this step, the ImageNet pre-trained
+networks (i.e. ResNet-50) weights will be downloaded. If they do not download (you will
+see this downloading in the terminal, then you may not have permission to do so (
+something we have seen with some Windows users - see the **[
+WIKI troubleshooting for more help!](
+https://github.com/DeepLabCut/DeepLabCut/wiki/Troubleshooting-Tips)**).
-At this point you also select your neural network type. Please see Lauer et al. 2021 for options. For **create_multianimaltraining_dataset** we already changed this such that by default you will use imgaug, ADAM optimization, our new DLCRNet, and batch training. We suggest these defaults at this time. Then run:
+Then run:
```python
-deeplabcut.create_multianimaltraining_dataset(config_path)
+deeplabcut.create_training_dataset(config_path)
```
- The set of arguments in the function will shuffle the combined labeled dataset and split it to create train and test
@@ -242,47 +352,33 @@ keeps track of how often the dataset was refined).
- OPTIONAL: If the user wishes to benchmark the performance of the DeepLabCut, they can create multiple
training datasets by specifying an integer value to the `num_shuffles`; see the docstring for more details.
-- Each iteration of the creation of a training dataset will create several files, which is used by the feature detectors,
-and a ``.pickle`` file that contains the meta information about the training dataset. This also creates two subdirectories
-within **dlc-models** called ``test`` and ``train``, and these each have a configuration file called pose_cfg.yaml.
-Specifically, the user can edit the **pose_cfg.yaml** within the **train** subdirectory before starting the training. These
-configuration files contain meta information with regard to the parameters of the feature detectors. Key parameters
-are listed in Box 2.
-
-- At this step, the ImageNet pre-trained networks (i.e. ResNet-50) weights will be downloaded. If they do not download (you will see this downloading in the terminal, then you may not have permission to do so (something we have seen with some Windows users - see the **[WIKI troubleshooting for more help!](https://github.com/DeepLabCut/DeepLabCut/wiki/Troubleshooting-Tips)**).
-
-**OPTIONAL POINTS:**
-
-With the data-driven skeleton selection introduced in 2.2rc1+, DLC networks are trained by default
-on complete skeletons (i.e., they learn all possible redundant connections), before being optimally pruned
-at model evaluation. Although this procedure is by far superior to manually defining a graph,
-we leave manually-defining a skeleton as an option for the advanced user:
-
-```python
-my_better_graph = [[0, 1], [1, 2], [2, 3]] # These are indices in the list of multianimalbodyparts
-deeplabcut.create_multianimaltraining_dataset(config_path, paf_graph=my_better_graph)
-```
-
-Alternatively, the `skeleton` defined in the `config.yaml` file can also be used:
-
-```python
-deeplabcut.create_multianimaltraining_dataset(config_path, paf_graph='config')
-```
-
-Importantly, if a user-defined graph is used it still is required to cover all multianimalbodyparts at least once.
-
-**DATA AUGMENTATION:** At this stage you can also decide what type of augmentation to use. The default loaders work well for most all tasks (as shown on www.deeplabcut.org), but there are many options, more data augmentation, intermediate supervision, etc. Please look at the [**pose_cfg.yaml**](https://github.com/DeepLabCut/DeepLabCut/blob/master/deeplabcut/pose_cfg.yaml) file for a full list of parameters **you might want to change before running this step.** There are several data loaders that can be used. For example, you can use the default loader (introduced and described in the Nature Protocols paper), [TensorPack](https://github.com/tensorpack/tensorpack) for data augmentation (currently this is easiest on Linux only), or [imgaug](https://imgaug.readthedocs.io/en/latest/). We recommend `imgaug` (which is default now!). You can set this by passing:``` deeplabcut.create_training_dataset(config_path, augmenter_type='imgaug') ```
-
-The differences of the loaders are as follows:
-- `default`: our standard DLC 2.0 introduced in Nature Protocols variant (scaling, auto-crop augmentation) *will be renamed to `crop_scale` in a future release!*
-- `imgaug`: a lot of augmentation possibilities, efficient code for target map creation & batch sizes >1 supported. You can set the parameters such as the `batch_size` in the `pose_cfg.yaml` file for the model you are training.
-- `tensorpack`: a lot of augmentation possibilities, multi CPU support for fast processing, target maps are created less efficiently than in imgaug, does not allow batch size>1
-- `deterministic`: only useful for testing, freezes numpy seed; otherwise like default.
-
-Our recent [A Primer on Motion Capture with Deep Learning: Principles, Pitfalls, and Perspectives](https://www.cell.com/neuron/pdf/S0896-6273(20)30717-0.pdf), details the advantage of augmentation for a worked example (see Fig 7). TL;DR: use imgaug and use the symmetries of your data!
-
-
-Alternatively, you can set the loader (as well as other training parameters) in the **pose_cfg.yaml** file of the model that you want to train. Note, to get details on the options, look at the default file: [**pose_cfg.yaml**](https://github.com/DeepLabCut/DeepLabCut/blob/master/deeplabcut/pose_cfg.yaml).
+- Each iteration of the creation of a training dataset will create several files, which
+is used by the feature detectors, and a ``.pickle`` file that contains the meta
+information about the training dataset. This also creates two subdirectories within
+**dlc-models-pytorch** (**dlc-models** for the TensorFlow engine) called ``test`` and
+``train``, and these each have a configuration file called pose_cfg.yaml. Specifically,
+the user can edit the **pytorch_config.yaml** (**pose_cfg.yaml** for TensorFlow engine)
+within the **train** subdirectory before starting the training. These configuration
+files contain meta information with regard to the parameters of the feature detectors.
+Key parameters are listed in Box 2.
+
+**DATA AUGMENTATION:** At this stage you can also decide what type of augmentation to
+use. Once you've called `create_training_dataset`, you can edit the
+[**pytorch_config.yaml**](dlc3-pytorch-config) file that was created (or for the
+TensorFlow engine, the [**pose_cfg.yaml**](
+https://github.com/DeepLabCut/DeepLabCut/blob/master/deeplabcut/pose_cfg.yaml) file).
+
+- PyTorch Engine: [Albumentations](https://albumentations.ai/docs/) is used for data
+augmentation. Look at the [**pytorch_config.yaml**](dlc3-pytorch-config) for more
+information about image augmentation options.
+- TensorFlow Engine: The default augmentation works well for most tasks (as shown on
+www.deeplabcut.org), but there are many options, more data augmentation, intermediate
+supervision, etc. Only `imgaug` augmentation is available for multi-animal projects.
+
+[A Primer on Motion Capture with Deep Learning: Principles, Pitfalls, and Perspectives](
+https://www.cell.com/neuron/pdf/S0896-6273(20)30717-0.pdf), details the advantage of
+augmentation for a worked example (see Fig 8). TL;DR: use imgaug and use the symmetries
+of your data!
Importantly, image cropping as previously done with `deeplabcut.cropimagesandlabels` in multi-animal projects
is now part of the augmentation pipeline. In other words, image crops are no longer stored in labeled-data/..._cropped
@@ -292,118 +388,254 @@ In addition, one can specify a crop sampling strategy: crop centers can either b
As a reminder, cropping images into smaller patches is a form of data augmentation that simultaneously
allows the use of batch processing even on small GPUs that could not otherwise accommodate larger images + larger batchsizes (this usually increases performance and decreasing training time).
+**MODEL COMPARISON**: You can also test several models by creating the same train/test
+split for different networks.
+You can easily do this in the Project Manager GUI (by selecting the "Use an existing
+data split" option), which also lets you compare PyTorch and TensorFlow models.
-### Train The Network:
-
-```python
-deeplabcut.train_network(config_path, allow_growth=True)
-```
+````{versionadded} 3.0.0
+You can now create new shuffles using the same train/test split as
+existing shuffles with `create_training_dataset_from_existing_split`. This allows you to
+compare model performance (between different architectures or when using different
+training hyper-parameters) as the shuffles were trained on the same data, and evaluated
+on the same test data!
-The set of arguments in the function starts training the network for the dataset created for one specific shuffle. Note that you can change the loader (imgaug/default/etc) as well as other training parameters in the **pose_cfg.yaml** file of the model that you want to train (before you start training).
+Example usage - creating 3 new shuffles (with indices 10, 11 and 12) for a ResNet 50
+pose estimation model, using the same data split as was used for shuffle 0:
-Example parameters that one can call:
```python
-deeplabcut.train_network(config_path, shuffle=1, trainingsetindex=0, gputouse=None, max_snapshots_to_keep=5, autotune=False, displayiters=100, saveiters=15000, maxiters=30000, allow_growth=True)
+deeplabcut.create_training_dataset_from_existing_split(
+ config_path,
+ from_shuffle=0,
+ shuffles=[10, 11, 12],
+ net_type="resnet_50",
+)
```
+````
-By default, the pretrained networks are not in the DeepLabCut toolbox (as they are around 100MB each), but they get downloaded before you train. However, if not previously downloaded from the TensorFlow model weights, it will be downloaded and stored in a subdirectory *pre-trained* under the subdirectory *models* in *Pose_Estimation_Tensorflow*.
-At user specified iterations during training checkpoints are stored in the subdirectory *train* under the respective iteration directory.
-
-If the user wishes to restart the training at a specific checkpoint they can specify the full path of the checkpoint to
-the variable ``init_weights`` in the **pose_cfg.yaml** file under the *train* subdirectory (see Box 2).
-
-**CRITICAL POINT:** It is recommended to train the networks for thousands of iterations until the loss plateaus (typically around **500,000**) if you use batch size 1, and **50-100K** if you use batchsize 8 (the default).
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_dataset
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_training_dataset.rst
+```
+````
-If you use **maDeepLabCut** the recommended training iterations is **20K-100K** (it automatically stops at 200K!), as we use Adam and batchsize 8; if you have to reduce the batchsize for memory reasons then the number of iterations needs to be increased.
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_model_comparison
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_training_model_comparison.rst
+```
+````
-The variables ``display_iters`` and ``save_iters`` in the **pose_cfg.yaml** file allows the user to alter how often the loss is displayed and how often the weights are stored.
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_dataset_from_existing_split
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_training_dataset_from_existing_split.rst
+```
+````
-**maDeepLabCut CRITICAL POINT:** For multi-animal projects we are using not only different and new output layers, but also new data augmentation, optimization, learning rates, and batch training defaults. Thus, please use a lower ``save_iters`` and ``maxiters``. I.e. we suggest saving every 10K-15K iterations, and only training until 50K-100K iterations. We recommend you look closely at the loss to not overfit on your data. The bonus, training time is much less!!!
+### (G) Train The Network
-**Parameters:**
+```python
+deeplabcut.train_network(config_path, shuffle=1)
```
-config : string
- Full path of the config.yaml file as a string.
-shuffle: int, optional
- Integer value specifying the shuffle index to select for training. Default is set to 1
+The set of arguments in the function starts training the network for the dataset created
+for one specific shuffle. Note that you can change training parameters in the
+[**pytorch_config.yaml**](dlc3-pytorch-config) file (or **pose_cfg.yaml** for TensorFlow
+models) of the model that you want to train (before you start training).
+
+At user specified iterations during training checkpoints are stored in the subdirectory
+*train* under the respective iteration & shuffle directory.
-trainingsetindex: int, optional
- Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml).
+````{admonition} Tips on training models with the PyTorch Engine
+:class: dropdown
-gputouse: int, optional. Natural number indicating the number of your GPU (see number in nvidia-smi). If you do not have a GPU, put None.
-See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
+Example parameters that one can call:
-max_snapshots_to_keep: int, or None. Sets how many snapshots are kept, i.e. states of the trained network. For every saving iteration a snapshot is stored, however, only the last max_snapshots_to_keep many are kept! If you change this to None, then all are kept.
-See: https://github.com/DeepLabCut/DeepLabCut/issues/8#issuecomment-387404835
+```python
+deeplabcut.train_network(
+ config_path,
+ shuffle=1,
+ trainingsetindex=0,
+ device="cuda:0",
+ max_snapshots_to_keep=5,
+ displayiters=100,
+ save_epochs=5,
+ epochs=200,
+)
+```
+
+Pytorch models in DeepLabCut 3.0 are trained for a set number of epochs, instead of a
+maximum number of iterations (which is what was used for TensorFlow models). An epoch
+is a single pass through the training dataset, which means your model has seen each
+training image exactly once. So if you have 64 training images for your network, an
+epoch is 64 iterations with batch size 1 (or 32 iterations with batch size 2, 16 with
+batch size 4, etc.).
+
+By default, the pretrained networks are not in the DeepLabCut toolbox (as they can be
+more than 100MB), but they get downloaded automatically before you train.
+
+If the user wishes to restart the training at a specific checkpoint they can specify the
+full path of the checkpoint to the variable ``resume_training_from`` in the [
+**pytorch_config.yaml**](
+dlc3-pytorch-config) file (checkout the "Restarting Training at a Specific Checkpoint"
+section of the docs) under the *train* subdirectory.
+
+**CRITICAL POINT:** It is recommended to train the networks **until the loss plateaus**
+(depending on the dataset, model architecture and training hyper-parameters this happens
+after 100 to 250 epochs of training).
+
+The variables ``display_iters`` and ``save_epochs`` in the [**pytorch_config.yaml**](
+dlc3-pytorch-config) file allows the user to alter how often the loss is displayed
+and how often the weights are stored. We suggest saving every 5 to 25 epochs.
+````
-autotune: property of TensorFlow, somehow faster if 'false' (as Eldar found out, see https://github.com/tensorflow/tensorflow/issues/13317). Default: False
+````{admonition} Tips on training models with the TensorFlow Engine
+:class: dropdown
-displayiters: this variable is actually set in pose_config.yaml. However, you can overwrite it with this hack. Don't use this regularly, just if you are too lazy to dig out
-the pose_config.yaml file for the corresponding project. If None, the value from there is used, otherwise it is overwritten! Default: None
+Example parameters that one can call:
-saveiters: this variable is actually set in pose_config.yaml. However, you can overwrite it with this hack. Don't use this regularly, just if you are too lazy to dig out
-the pose_config.yaml file for the corresponding project. If None, the value from there is used, otherwise it is overwritten! Default: None
+```python
+deeplabcut.train_network(
+ config_path,
+ shuffle=1,
+ trainingsetindex=0,
+ gputouse=None,
+ max_snapshots_to_keep=5,
+ autotune=False,
+ displayiters=100,
+ saveiters=15000,
+ maxiters=30000,
+ allow_growth=True,
+)
+```
+
+By default, the pretrained networks are not in the DeepLabCut toolbox (as they are
+around 100MB each), but they get downloaded before you train. However, if not previously
+downloaded from the TensorFlow model weights, it will be downloaded and stored in a
+subdirectory *pre-trained* under the subdirectory *models* in
+*Pose_Estimation_Tensorflow*. At user specified iterations during training checkpoints
+are stored in the subdirectory *train* under the respective iteration directory.
+
+If the user wishes to restart the training at a specific checkpoint they can specify the
+full path of the checkpoint to the variable ``init_weights`` in the **pose_cfg.yaml**
+file under the *train* subdirectory (see Box 2).
+
+**CRITICAL POINT:** It is recommended to train the networks for thousands of iterations
+until the loss plateaus (typically around **500,000**) if you use batch size 1, and
+**50-100K** if you use batchsize 8 (the default).
+
+If you use **maDeepLabCut** the recommended training iterations is **20K-100K**
+(it automatically stops at 200K!), as we use Adam and batchsize 8; if you have to reduce
+ the batchsize for memory reasons then the number of iterations needs to be increased.
+
+The variables ``display_iters`` and ``save_iters`` in the **pose_cfg.yaml** file allows
+the user to alter how often the loss is displayed and how often the weights are stored.
+
+**maDeepLabCut CRITICAL POINT:** For multi-animal projects we are using not only
+different and new output layers, but also new data augmentation, optimization, learning
+rates, and batch training defaults. Thus, please use a lower ``save_iters`` and
+``maxiters``. I.e. we suggest saving every 10K-15K iterations, and only training until
+50K-100K iterations. We recommend you look closely at the loss to not overfit on your
+data. The bonus, training time is much less!!!
+````
-maxiters: This sets how many iterations to train. This variable is set in pose_config.yaml. However, you can overwrite it with this. If None, the value from there is used, otherwise it is overwritten! Default: None
+````{admonition} Click the button to see API Docs for train_network
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.train_network.rst
```
+````
-### Evaluate the Trained Network:
+### (H) Evaluate the Trained Network
+
+It is important to evaluate the performance of the trained network. This performance is
+measured by computing two metrics:
+
+- **Average root mean square error** (RMSE) between the manual labels and the ones
+predicted by your trained DeepLabCut model. The RMSE is proportional to the mean average
+Euclidean error (MAE) between the manual labels and the ones predicted by DeepLabCut.
+The MAE is displayed for all pairs and only likely pairs (>p-cutoff). This helps to
+exclude, for example, occluded body parts. One of the strengths of DeepLabCut is that
+due to the probabilistic output of the scoremap, it can, if sufficiently trained, also
+reliably report if a body part is visible in a given frame. (see discussions of finger
+tips in reaching and the Drosophila legs during 3D behavior in [Mathis et al, 2018]).
+- **Mean Average Precision** (mAP) and **Mean Average Recall** (mAR) for the individuals
+predicted by your trained DeepLabCut model. This metric describes the precision of your
+model, based on a considered definition of what a correct detection of an individual is.
+It isn't as useful for single-animal models, as RMSE does a great job of evaluating your
+model in that case.
+
+```{admonition} A more detailed description of mAP and mAR
+:class: dropdown
+
+For multi-animal pose estimation, multiple predictions can be made for each image.
+We want to get some idea of the proportion of correct predictions among all predictions
+that are made.
+However, the notion of "correct prediction" for pose estimation is not straightforward:
+is a prediction correct if all predicted keypoints are within 5 pixels of the ground
+truth? Within 2 pixels of the ground truth? What if all pixels but one match the ground
+truth perfectly, but the wrong prediction is 50 pixels away? Mean average precision (
+and mean average recall) estimate the precision/recall of your models by setting
+different "thresholds of correctness" and averaging results. How "correct" a
+prediction is can be evaluated through [object-keypoint similarity](
+https://cocodataset.org/#keypoints-eval).
+
+A good resource to get a deeper understanding of mAP is the [Stanford CS230 course](
+https://cs230.stanford.edu/section/8/#object-detection-iou-ap-and-map). While it
+describes mAP for object detection (where bounding boxes are predicted instead of
+keypoints), the same metric can be computed for pose estimation, where similarity
+between predictions and ground truth is computed through [object-keypoint similarity](
+https://cocodataset.org/#keypoints-eval) instead of intersection-over-union (IoU).
+```
+
+It's also important to visually inspect predictions on individual frames to assess the
+performance of your model. You can do this by setting `plotting=True` when you call
+`evaluate_network`. The evaluation results are computed by typing:
-Here, for traditional projects you will get a pixel distance metric and you should inspect the individual frames:
```python
-deeplabcut.evaluate_network(config_path, plotting=True)
+deeplabcut.evaluate_network(config_path, Shuffles=[1], plotting=True)
```
-:movie_camera:[VIDEO TUTORIAL AVAILABLE!](https://www.youtube.com/watch?v=bgfnz1wtlpo)
-It is important to evaluate the performance of the trained network. This performance is measured by computing
-the mean average Euclidean error (MAE; which is proportional to the average root mean square error) between the
-manual labels and the ones predicted by DeepLabCut. The MAE is saved as a comma separated file and displayed
-for all pairs and only likely pairs (>p-cutoff). This helps to exclude, for example, occluded body parts. One of the
-strengths of DeepLabCut is that due to the probabilistic output of the scoremap, it can, if sufficiently trained, also
-reliably report if a body part is visible in a given frame. (see discussions of finger tips in reaching and the Drosophila
-legs during 3D behavior in [Mathis et al, 2018]). The evaluation results are computed by typing:
+🎥 [VIDEO TUTORIAL AVAILABLE!](https://www.youtube.com/watch?v=bgfnz1wtlpo)
Setting ``plotting`` to True plots all the testing and training frames with the manual and predicted labels; these will
-be colored by body part type by default. They can alternatively be colored by individual by passing `plotting`=`individual`.
+be colored by body part type by default. They can alternatively be colored by individual by passing `plotting="individual"`.
The user should visually check the labeled test (and training) images that are created in the ‘evaluation-results’ directory.
Ideally, DeepLabCut labeled unseen (test images) according to the user’s required accuracy, and the average train
and test errors are comparable (good generalization). What (numerically) comprises an acceptable MAE depends on
many factors (including the size of the tracked body parts, the labeling variability, etc.). Note that the test error can
also be larger than the training error due to human variability (in labeling, see Figure 2 in Mathis et al, Nature Neuroscience 2018).
-**Optional parameters:**
-```
- Shuffles: list, optional -List of integers specifying the shuffle indices of the training dataset. The default is [1]
-
- plotting: bool, optional -Plots the predictions on the train and test images. The default is `False`; if provided it must be either `True` or `False`
-
- show_errors: bool, optional -Display train and test errors. The default is `True`
-
- comparisonbodyparts: list of bodyparts, Default is all -The average error will be computed for those body parts only (Has to be a subset of the body parts).
-
- gputouse: int, optional -Natural number indicating the number of your GPU (see number in nvidia-smi). If you do not have a GPU, put None. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
-```
-
The plots can be customized by editing the **config.yaml** file (i.e., the colormap, scale, marker size (dotsize), and
-transparency of labels (alphavalue) can be modified). By default each body part is plotted in a different color
+transparency of labels (alpha-value) can be modified). By default each body part is plotted in a different color
(governed by the colormap) and the plot labels indicate their source. Note that by default the human labels are
-plotted as plus (‘+’), DeepLabCut’s predictions either as ‘.’ (for confident predictions with likelihood > p-cutoff) and
+plotted as plus (‘+’), DeepLabCut’s predictions either as ‘.’ (for confident predictions with likelihood > `pcutoff`) and
’x’ for (likelihood <= `pcutoff`).
-The evaluation results for each shuffle of the training dataset are stored in a unique subdirectory in a newly created
-directory ‘evaluation-results’ in the project directory. The user can visually inspect if the distance between the labeled
-and the predicted body parts are acceptable. In the event of benchmarking with different shuffles of same training
-dataset, the user can provide multiple shuffle indices to evaluate the corresponding network. If the generalization is
-not sufficient, the user might want to:
+The evaluation results for each shuffle of the training dataset are stored in a unique
+subdirectory in a newly created directory ‘evaluation-results-pytorch’ (or
+‘evaluation-results’ for TensorFlow models) in the project directory.
+The user can visually inspect if the distance between the labeled and the predicted body
+parts are acceptable. In the event of benchmarking with different shuffles of same training
+dataset, the user can provide multiple shuffle indices to evaluate the corresponding
+network. If the generalization is not sufficient, the user might want to:
-• check if the labels were imported correctly; i.e., invisible points are not labeled and the points of interest are
-labeled accurately
+• check if the labels were imported correctly; i.e., invisible points are not labeled
+and the points of interest are labeled accurately
• make sure that the loss has already converged
• consider labeling additional images and make another iteration of the training data set
+````{admonition} Click the button to see API Docs for evaluate_network
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.evaluate_network.rst
+```
+````
+
**maDeepLabCut: (or on normal projects!)**
In multi-animal projects, model evaluation is crucial as this is when
@@ -416,9 +648,10 @@ You should also plot the scoremaps, locref layers, and PAFs to assess performanc
```python
deeplabcut.extract_save_all_maps(config_path, shuffle=shuffle, Indices=[0, 5])
```
-you can drop "Indices" to run this on all training/testing images (this is very slow!)
-### Evaluating your network on videos
+You can drop "Indices" to run this on all training/testing images (this is very slow!)
+
+### (I) Analyze new Videos
**-------------------- DECISION POINT -------------------**
@@ -428,12 +661,28 @@ you can drop "Indices" to run this on all training/testing images (this is very
Please run:
```python
-scorername = deeplabcut.analyze_videos(config_path,['/fullpath/project/videos/testVideo.mp4'], videotype='.mp4')
-deeplabcut.create_video_with_all_detections(config_path, ['/fullpath/project/videos/testVideo.mp4'], videotype='.mp4')
+videos_to_analyze = ['/fullpath/project/videos/testVideo.mp4']
+scorername = deeplabcut.analyze_videos(config_path, videos_to_analyze, videotype='.mp4')
+deeplabcut.create_video_with_all_detections(config_path, videos_to_analyze, videotype='.mp4')
```
-Please note that you do **not** get the .h5/csv file you might be used to getting (this comes after tracking). You will get a `pickle` file that is used in `create_video_with_all_detections`.
-Another sanity check may be to examine the distributions of edge affinity costs using `deeplabcut.utils.plot_edge_affinity_distributions`. Easily separable distributions indicate that the model has learned strong links to group keypoints into distinct individuals — likely a necessary feature for the assembly stage (note that the amount of overlap will also depend on the amount of interactions between your animals in the daset).
-IF you have good clean out video, ending in `....full.mp4` (and the evaluation metrics look good, scoremaps look good, plotted evaluation images, and affinity distributions are far apart for most edges), then go forward!!!
+
+Please note that you do **not** get the .h5/csv file you might be used to getting (this
+comes after tracking). You will get a `pickle` file that is used in
+`create_video_with_all_detections`.
+
+For models predicting part-affinity fields, another sanity check may be to
+examine the distributions of edge affinity costs using `deeplabcut.utils.plot_edge_affinity_distributions`. Easily separable distributions
+indicate that the model has learned strong links to group keypoints into distinct
+individuals — likely a necessary feature for the assembly stage (note that the amount of
+overlap will also depend on the amount of interactions between your animals in the
+dataset). All TensorFlow multi-animal models use part-affinity fields and PyTorch models
+consisting of just a backbone name (e.g. `resnet_50`, `resnet_101`) use part-affinity
+fields. If you're unsure whether your PyTorch model has a one, check
+the **pytorch_config.yaml** for a `DLCRNetHead`.
+
+IF you have good clean out video, ending in `....full.mp4` (and the evaluation metrics
+look good, scoremaps look good, plotted evaluation images, and affinity distributions
+are far apart for most edges), then go forward!!!
If this does not look good, we recommend extracting and labeling more frames (even from more videos). Try to label close interactions of animals for best performance. Once you label more, you can create a new training set and train.
@@ -441,30 +690,53 @@ You can either:
1. extract more frames manually from existing or new videos and label as when initially building the training data set, or
2. let DeepLabCut find frames where keypoints were poorly detected and automatically extract those for you. All you need is
to run:
+
```python
deeplabcut.find_outliers_in_raw_data(config_path, pickle_file, video_file)
```
+
where pickle_file is the `_full.pickle` one obtains after video analysis.
Flagged frames will be added to your collection of images in the corresponding labeled-data folders for you to label.
-## Animal Assembly and Tracking across frames
+### Animal Assembly and Tracking across frames
-After pose estimation, now you perform assembly and tracking. *NEW* in 2.2 is a novel data-driven way to set the optimal skeleton and assembly metrics, so this no longer requires user input. The metrics, in case you do want to edit them, can be found in the `inference_cfg.yaml` file.
+After pose estimation, now you perform assembly and tracking.
+
+````{versionadded} v2.2.0
+*NEW* in 2.2 is a novel data-driven way to set the optimal skeleton and assembly
+metrics, so this no longer requires user input. The metrics, in case you do want to edit
+them, can be found in the `inference_cfg.yaml` file.
+````
### Optimized Animal Assembly + Video Analysis:
-- Please note that **novel videos DO NOT need to be added to the config.yaml file**. You can simply have a folder elsewhere on your computer and pass the video folder (then it will analyze all videos of the specified type (i.e. ``videotype='.mp4'``), or pass the path to the **folder** or exact video(s) you wish to analyze:
+Please note that **novel videos DO NOT need to be added to the config.yaml file**. You
+can simply have a folder elsewhere on your computer and pass the video folder (then it
+will analyze all videos of the specified type (i.e. ``videotype='.mp4'``), or pass the
+path to the **folder** or exact video(s) you wish to analyze:
```python
deeplabcut.analyze_videos(config_path, ['/fullpath/project/videos/'], videotype='.mp4', auto_track=True)
```
+
#### IF auto_track = True:
-- *NEW* in 2.2.0.3+: `deeplabcut.analyze_videos` has a new argument `auto_track=True`, chaining pose estimation, tracking, and stitching in a single function call with defaults we found to work well. Thus, you'll now get the `.h5` file you might be used to getting in standard DLC. If `auto_track=False`, one must run `convert_detections2tracklets` and `stitch_tracklets` manually (see below), granting more control over the last steps of the workflow (ideal for advanced users).
+```{versionadded} v2.2.0.3
+A new argument `auto_track=True`, was added to `deeplabcut.analyze_videos` chaining pose
+estimation, tracking, and stitching in a single function call with defaults we found to
+work well. Thus, you'll now get the `.h5` file you might be used to getting in standard
+DLC. If `auto_track=False`, one must run `convert_detections2tracklets` and
+`stitch_tracklets` manually (see below), granting more control over the last steps of
+the workflow (ideal for advanced users).
+```
#### IF auto_track = False:
- - You can validate the tracking parameters. Namely, you can iteratively change the parameters, run `convert_detections2tracklets` then load them in the GUI (`refine_tracklets`) if you want to look at the performance. If you want to edit these, you will need to open the `inference_cfg.yaml` file (or click button in GUI). The options are:
+You can validate the tracking parameters. Namely, you can iteratively change the
+parameters, run `convert_detections2tracklets` then load them in the GUI
+(`refine_tracklets`) if you want to look at the performance. If you want to edit these,
+you will need to open the `inference_cfg.yaml` file (or click button in GUI). The
+options are:
```python
# Tracking:
@@ -503,8 +775,13 @@ from the final h5 file as was customary in single animal projects.
**Next, tracklets are stitched to form complete tracks with:
```python
-deeplabcut.stitch_tracklets(config_path, ['videofile_path'], videotype='mp4',
- shuffle=1, trainingsetindex=0)
+deeplabcut.stitch_tracklets(
+ config_path,
+ ['videofile_path'],
+ videotype='mp4',
+ shuffle=1,
+ trainingsetindex=0,
+)
```
Note that the base signature of the function is identical to `analyze_videos` and `convert_detections2tracklets`.
@@ -515,15 +792,42 @@ can be directly specified as follows:
```python
deeplabcut.stitch_tracklets(..., n_tracks=n)
```
+
In such cases, file columns will default to dummy animal names (ind1, ind2, ..., up to indn).
+#### API Docs
+
+````{admonition} Click the button to see API Docs for analyze_videos
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.analyze_videos.rst
+```
+````
+
+````{admonition} Click the button to see API Docs for convert_detections2tracklets
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.convert_detections2tracklets.rst
+```
+````
+
+````{admonition} Click the button to see API Docs for stitch_tracklets
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.stitch_tracklets.rst
+```
+````
+
### Using Unsupervised Identity Tracking:
-In Lauer et al. 2022 we introduced a new method to do unsupervised reID of animals. Here, you can use the tracklets to learn the identity of animals to enhance your tracking performance. To use the code:
+In Lauer et al. 2022 we introduced a new method to do unsupervised reID of animals.
+Here, you can use the tracklets to learn the identity of animals to enhance your
+tracking performance. To use the code:
```python
deeplabcut.transformer_reID(config, videos_to_analyze, n_tracks=None, videotype="mp4")
```
+
Note you should pass the n_tracks (number of animals) you expect to see in the video.
### Refine Tracklets:
@@ -549,7 +853,7 @@ Short demo:
-### Once you have analyzed video data (and refined your maDeepLabCut tracklets):
+### (J) Filter Pose Data
Firstly, Here are some tips for scaling up your video analysis, including looping over many folders for batch processing: https://github.com/DeepLabCut/DeepLabCut/wiki/Batch-Processing-your-Analysis
@@ -559,7 +863,14 @@ deeplabcut.filterpredictions(config_path,['/fullpath/project/videos/reachingvide
```
Note, this creates a file with the ending filtered.h5 that you can use for further analysis. This filtering step has many parameters, so please see the full docstring by typing: ``deeplabcut.filterpredictions?``
-### Plotting Results:
+````{admonition} Click the button to see API Docs
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.filterpredictions.rst
+```
+````
+
+### (K) Plot Trajectories , (L) Create Labeled Videos
- **NOTE :bulb::mega::** Before you create a video, you should set what threshold to use for plotting. This is set in the `config.yaml` file as `pcutoff` - if you have a well trained network, this should be high, i.e. set it to `0.8` or higher! IF YOU FILLED IN GAPS, you need to set this to `0` to "see" the filled in parts.
@@ -579,6 +890,20 @@ Create videos:
(more details [here](functionDetails.md#i-video-analysis-and-plotting-results))
+````{admonition} Click the button to see API Docs for plot_trajectories
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.plot_trajectories.rst
+```
+````
+
+````{admonition} Click the button to see API Docs for create_labeled_video
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_labeled_video.rst
+```
+````
+
### HELP:
In ipython/Jupyter notebook:
diff --git a/docs/pytorch/architectures.md b/docs/pytorch/architectures.md
index 5a14c85199..53dfaca621 100644
--- a/docs/pytorch/architectures.md
+++ b/docs/pytorch/architectures.md
@@ -10,22 +10,32 @@ from deeplabcut.pose_estimation_pytorch import available_models
print(available_models())
```
-## Backbones (neural networks)
+You can see a list of supported object detection architectures/variants by using:
-Several backbones are currently implemented in DeepLabCut PyTorch (more will come, and you can add more easily in our new model registry).
+```python
+from deeplabcut.pose_estimation_pytorch import available_detectors
+print(available_detectors())
+```
+
+## Neural Networks Architectures
+
+Several architectures are currently implemented in DeepLabCut PyTorch (more will come,
+and you can add more easily in our new model registry).
**ResNets**
- Adapted from [He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on Computer Vision and Pattern Recognition. 2016.](https://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html) and [Insafutdinov, Eldar et al. "DeeperCut: A Deeper, Stronger, and Faster Multi-Person Pose Estimation Model". European Conference on Computer Vision (ECCV) 2016.]
-- Current variants are `resnet_50`, `resnet_101`, `top_down_resnet_101`, `top_down_resnet_50`
+- Current bottom-up variants are `resnet_50`, `resnet_101`
+- Current top-down variants are `top_down_resnet_101`, `top_down_resnet_50`
**HRNet**
- Adapted from [Wang, Jingdong, et al. "Deep high-resolution representation learning for visual recognition." IEEE transactions on pattern analysis and machine intelligence 43.10 (2020): 3349-3364.](https://arxiv.org/abs/1908.07919)
-- Current variants are `hrnet_w18`, `hrnet_w32`, `hrnet_w48`, `top_down_hrnet_w18`, `top_down_hrnet_w32`, `top_down_hrnet_w48`
+- Current variants are `hrnet_w18`, `hrnet_w32`, `hrnet_w48`,
+- Current top-down variants are `top_down_hrnet_w18`, `top_down_hrnet_w32`, `top_down_hrnet_w48`
- Slower but typically more powerful than ResNets
**DEKR**
- Adapted from [Geng, Zigang et al. "Bottom-Up Human Pose Estimation Via Disentangled Keypoint Regression." Proceedings of the IEEE conference on Computer Vision and Pattern Recognition. 2021.](https://openaccess.thecvf.com/content/CVPR2021/papers/Geng_Bottom-Up_Human_Pose_Estimation_via_Disentangled_Keypoint_Regression_CVPR_2021_paper.pdf)
-- This model uses HRNet as a backbone. It learns to predict the center of each animal, and predicts the offset between each animal center and their keypoints
+- This model is a bottom-up model using HRNet as a backbone. It learns to predict the center of each animal, and predicts the offset between each animal center and their keypoints
- Current variants that are implemented (from smallest to largest): `dekr_w18`, `dekr_w32`, `dekr_w48`
- Note, this is a powerful multi-animal model but very heavy (slow)
@@ -40,6 +50,10 @@ Several backbones are currently implemented in DeepLabCut PyTorch (more will co
- This model uses a multi-scale variant of a ResNet as a backbone, and part-affinity fields to assemble individuals
- Variants: `dlcrnet_stride16_ms5`, `dlcrnet_stride32_ms5`
+**RTMPose**
+- From [Jiang, Tao et al. "RTMPose: Real-Time Multi-Person Pose Estimation based on MMPose"](https://arxiv.org/abs/2303.07399)
+- Top-down pose estimation model using a fast CSPNeXt backbone with a SimCC-style head
+- Variants: `rtmpose_s`, `rtmpose_m`, `rtmpose_x`
**AnimalTokenPose**
- Adapted from [Li, Yanjie, et al. "Tokenpose: Learning keypoint tokens for human pose estimation." Proceedings of the IEEE/CVF International conference on computer vision. 2021.](https://arxiv.org/abs/2104.03516) as in Ye et al. "SuperAnimal pretrained pose estimation models for behavioral analysis." Nature Communications. 2024](https://arxiv.org/abs/2203.07436)
@@ -78,8 +92,9 @@ individual). As localization of individuals is handled by the detector, this sim
the pose task to single-animal pose estimation!
Hence any single-animal model can be transformed into a top-down, multi-animal model. To
-do so, simply prefix `top_down` to your single-animal model name. Currently only a
-single FasterRCNN variant is available as a detector. Other variants will be added soon!
+do so, simply prefix `top_down` to your single-animal model name. Currently, the
+following detectors are available: `ssdlite`, `fasterrcnn_mobilenet_v3_large_fpn`,
+`fasterrcnn_resnet50_fpn_v2`. Other variants will be added soon!
The pose model for top-down nets is simply the backbone followed by a single convolution
for pose estimation. It's also possible to add deconvolutional layers to top-down model
diff --git a/docs/pytorch/pytorch_config.md b/docs/pytorch/pytorch_config.md
index a1ce5768c9..13e6244ea0 100644
--- a/docs/pytorch/pytorch_config.md
+++ b/docs/pytorch/pytorch_config.md
@@ -26,6 +26,7 @@ runner: # configuring the runner used for training
train_settings: # generic training settings, such as batch size and maximum number of epochs
...
logger: # optional: the configuration for a logger if you want one
+resume_training_from: # optional: restart the training at the specific checkpoint
```
## Sections
@@ -60,60 +61,21 @@ data:
normalize_images: true # this should always be set to true
train:
affine:
- p: 0.9
+ p: 0.5
rotation: 30
- translation: 40
- collate: # rescales the images when putting them in a batch
- type: ResizeFromDataSizeCollate
- min_scale: 0.4
- max_scale: 1.0
- min_short_side: 128
- max_short_side: 1152
- multiple_of: 32
- to_square: false
+ scaling: [0.5, 1.25]
+ translation: 0
covering: true
+ crop_sampling:
+ width: 448 # if your images are very small or very large, you may need to edit!
+ height: 448 # see below for more information about crop_sampling!
+ max_shift: 0.1
+ method: hybrid
gaussian_noise: 12.75
- hist_eq: true
motion_blur: true
normalize_images: true # this should always be set to true
```
-One of the most important elements is the `collate` configuration. If all images in your
-dataset have the same size, then it doesn't necessarily need to be added (but might
-still be beneficial). But if you have images of different sizes, then you'll need to
-define a way of "combining" these images into a single tensor of shape `(B, 3, H, W)`.
-The default way to do this is to use the `ResizeFromDataSizeCollate` collate function
-(other collate functions are defined in
-`deeplabcut/pose_estimation_pytorch/data/collate.py`). For each batch to collate, this
-implementation:
-1. Selects the target width & height all images will be resized to by getting the size
-of the first image in the batch, and multiplying it by a scale sampled uniformly at
-random from `(min_scale, max_scale)`.
-2. Resizes all images in the batch (while preserving their aspect ratio) such that they
-are the smallest size such that the target size fits entirely in the image.
-3. Crops each resulting image into the target size with a random crop.
-
-**Collate**: Defines how images are collated into batches.
-
-```yaml
-collate: # rescales the images when putting them in a batch
- type: ResizeFromDataSizeCollate # You can also use `ResizeFromListCollate`
- max_shift: 10 # the maximum shift, in pixels, to add to the random crop (this means
- # there can be a slight border around the image)
- max_size: 1024 # the maximum size of the long edge of the image when resized. If the
- # longest side will be greater than this value, resizes such that the longest side
- # is this size, and the shortest side is smaller than the desired size. This is
- # useful to keep some information from images with extreme aspect ratios.
- min_scale: 0.4 # the minimum scale to resize the image with
- max_scale: 1.0 # the maximum scale to resize the image with
- min_short_side: 128 # the minimum size of the target short side
- max_short_side: 1152 # the maximum size of the target short side
- multiple_of: 32 # pads the target height, width such that they are multiples of 32
- to_square: false # instead of using the aspect ratio of the first image, only the
- # short side of the first image will be used to sample a "side", and the images will
- # be cropped in squares
-```
-
The following transformations are available for the `train` and `inference` keys.
**Affine**: Applies an affine (rotation, translation, scaling) transformation to the
@@ -165,16 +127,21 @@ the noise will be set as 12.75).
gaussian_noise: 12.75 # bool, float: add gaussian noise
```
-
**Horizontal Flips**: This flips the image horizontally around the y-axis. As the
resulting image is mirrored, it does not preserve labels (the left hand would become the
-right hand, and vice-versa). This augmentation should not be used for pose models if you
-have symmetric keypoints! However, it is safe to use it to train detectors.
+right hand, and vice versa). This augmentation should not be used for pose models if you
+have symmetric keypoints! However, it is safe to use it to train detectors. If you want
+to use horizontal flips with symmetric keypoints, you need to specify them through the
+`symmetries` parameter!
```yaml
-# if float > 0, the probability of applying a horizontal flip
-# if true, applies a horizontal flip with probability 0.5
-hflip: true # bool, float
+# augmentation for object detectors or when no symmetric (left/right) keypoints exist:
+hflip: true
+
+# augmentation if your bodyparts are [snout, eye_L, eye_R, ear_L, ear_R]
+hflip:
+ p: 0.5 # apply a horizontal flip with 50% probability
+ symmetries: [[1, 2], [3, 4]] # the indices of symmetric keypoints
```
**Histogram Equalization**: Applies histogram equalization with probability 50%.
@@ -189,33 +156,92 @@ hist_eq: true # bool: whether to apply histogram equalization
motion_blur: true # bool: whether to apply motion blur
```
-**Normalization**
+**Normalization**: This should always be set to `true`.
```yaml
normalize_images: true # normalizes images
```
-**Resizing**: Resizes the images while preserving the aspect ratio (first resizes to the
-maximum possible size, then adds padding for the missing pixels).
+#### Dealing with Variable Image Sizes
-```yaml
-resize:
- height: 640 # int: the height to which all images will be resized
- width: 480 # int: the width to which all images will be resized
- keep_ratio: true # bool: the
+```{NOTE}
+When training with batch size 1 (or if all images in your dataset have the same size),
+you don't need to worry about any of this! However, you can still use `crop_sampling`
+which may help your model generalize.
```
+When training with a batch size greater than 1, all images in a batch **must** have the
+same size. PyTorch **collates** all images into one tensor of shape `[b, c, h, w]`,
+where `b` is the batch size, `c` the number of channels in the image, `h` and `w` the
+height and width of images in the batches. There are a few different ways to ensure that
+all images in a batch have the same size:
+
+1. **Crop sampling**. This is the default behavior for the PyTorch engine in DeepLabCut.
+A part of each image (of a fixed size) is cropped and given to the model to train. See
+below for more information.
+2. **A custom collate function**. Collate functions define a way that images of different
+sizes can be combined into one tensor. This involves resizing and padding images to the
+same size and aspect ratio. Available collate functions are defined in
+`deeplabcut/pose_estimation_pytorch/data/collate.py`.
+3. **Resizing all images**. All images can simply be resized to the same size. This
+usually doesn't lead to the best performance.
+
**Resizing - Crop Sampling**: An alternative way to ensure all images have the same size
is through cropping. The `crop_sampling` crops images down to a maximum width and
height, with options to sample the center of the crop according to the positions of the
-keypoints.
+keypoints. The methods to sample the center of the crop are as follows:
+
+- `uniform`: randomly over the image
+- `keypoints`: randomly over the annotated keypoints
+- `density`: weighing preferentially dense regions of keypoints
+- `hybrid`: alternating randomly between `uniform` and `density`
```yaml
crop_sampling:
height: 400 # int: the height of the crop
width: 400 # int: the height of the crop
max_shift: 0.4 # float: maximum allowed shift of the cropping center position as a fraction of the crop size.
- method: hybrid # str: how to sample the center of crops (one of 'uniform', 'keypoints', 'density', 'hybrid')
+ method: hybrid # str: the center sampling method (one of 'uniform', 'keypoints', 'density', 'hybrid')
+```
+
+**Collate**: Defines how images are collated into batches. The default way collate
+function to use is `ResizeFromDataSizeCollate` (other collate functions are defined in
+`deeplabcut/pose_estimation_pytorch/data/collate.py`). For each batch to collate, this
+implementation:
+1. Selects the target width & height all images will be resized to by getting the size
+of the first image in the batch, and multiplying it by a scale sampled uniformly at
+random from `(min_scale, max_scale)`.
+2. Resizes all images in the batch (while preserving their aspect ratio) such that they
+are the smallest size such that the target size fits entirely in the image.
+3. Crops each resulting image into the target size with a random crop.
+
+```yaml
+collate: # rescales the images when putting them in a batch
+ type: ResizeFromDataSizeCollate # You can also use `ResizeFromListCollate`
+ max_shift: 10 # the maximum shift, in pixels, to add to the random crop (this means
+ # there can be a slight border around the image)
+ max_size: 1024 # the maximum size of the long edge of the image when resized. If the
+ # longest side will be greater than this value, resizes such that the longest side
+ # is this size, and the shortest side is smaller than the desired size. This is
+ # useful to keep some information from images with extreme aspect ratios.
+ min_scale: 0.4 # the minimum scale to resize the image with
+ max_scale: 1.0 # the maximum scale to resize the image with
+ min_short_side: 128 # the minimum size of the target short side
+ max_short_side: 1152 # the maximum size of the target short side
+ multiple_of: 32 # pads the target height, width such that they are multiples of 32
+ to_square: false # instead of using the aspect ratio of the first image, only the
+ # short side of the first image will be used to sample a "side", and the images will
+ # be cropped in squares
+```
+
+**Resizing**: Resizes the images while preserving the aspect ratio (first resizes to the
+maximum possible size, then adds padding for the missing pixels).
+
+```yaml
+resize:
+ height: 640 # int: the height to which all images will be resized
+ width: 480 # int: the width to which all images will be resized
+ keep_ratio: true # bool: whether the aspect ratio should be preserved when resizing
```
### Model
@@ -293,6 +319,9 @@ runner:
max_snapshots: 5 # the maximum number of snapshots to save (the "best" model does not count as one of them)
save_epochs: 25 # the interval between each snapshot save
save_optimizer_state: false # whether the optimizer state should be saved with the model snapshots (very little reason to set to true)
+ gpus: # GPUs to use to train the network
+ - 0
+ - 1
```
**Key metric**: Every time the model is evaluated on the test set, metrics are computed
@@ -351,6 +380,44 @@ milestone to the corresponding values in `lr_list`. Examples:
gamma: 0.1
```
+You can also use schedulers that use other schedulers as parameters, such as a
+[`ChainedScheduler`](
+https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ChainedScheduler.html)
+or a [`SequentialLR`](
+https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.SequentialLR.html).
+
+The `SequentialLR` can be particularly useful, such as to use a first scheduler for some
+warmup epochs, and a second scheduler later. An example usage would be:
+
+```yaml
+ # Multiply the learning rate by `factor` for the first `total_iters` epochs
+ # After 5 epochs, start decaying the learning rate by `gamma` every `step_size` epochs
+ # If the initial learning rate is set to 1, the learning rates will be:
+ # epoch 0: 0.01 - using ConstantLR
+ # epoch 1: 0.01 - using ConstantLR
+ # epoch 2: 1.0 - using ConstantLR
+ # epoch 3: 1.0 - using ConstantLR
+ # epoch 4: 1.0 - using ConstantLR
+ # epoch 5: 1.0 - using StepLR
+ # epoch 6: 1.0 - using StepLR
+ # epoch 7: 0.1 - using StepLR
+ # epoch 8: 0.1 - using StepLR
+ scheduler:
+ type: SequentialLR
+ params:
+ schedulers:
+ - type: ConstantLR
+ params:
+ factor: 0.01
+ total_iters: 2
+ - type: StepLR
+ params:
+ step_size: 2
+ gamma: 0.1
+ milestones:
+ - 5
+```
+
### Train Settings
The `train_settings` key contains parameters that are specific to training. For more
@@ -456,6 +523,7 @@ detector:
...
train_settings: # detector train settings (same keys as for the pose model)
...
+ resume_training_from: # optional: restart the training at the specific checkpoint
```
Currently, the only detectors available are `FasterRCNN` and `SSDLite`. However, multiple variants of
diff --git a/docs/pytorch/user_guide.md b/docs/pytorch/user_guide.md
index d33b8e5cf4..6bbb4f289d 100644
--- a/docs/pytorch/user_guide.md
+++ b/docs/pytorch/user_guide.md
@@ -1,42 +1,34 @@
(dlc3-user-guide)=
-# DeepLabCut 3.0 - Pytorch User Guide
+# DeepLabCut 3.0 - PyTorch User Guide
## Using DeepLabCut 3.0
**DeepLabCut 3.0 keeps the same high-level API that you know, but has a full new PyTorch backend.
-Moreover, it is a rewrite that more developer friendly, more powerful, and built for modern deep
+Moreover, it is a rewrite that is more developer friendly, more powerful, and built for modern deep
learning-based computer vision applications.**
-**NOTE**🔥: We suggest that if you're just starting with DeepLabCut you start with the PyTorch backend.
-You will easily know which "engine" you are using by looking at the main `config.yaml` file, or top right corner in the GUI.
-If you have DeepLabCut projects in TensorFlow, we've got you covered too: you can seamlessly switch to train your already labeled data
-by simply switching the engine (and thereby also compare performance). In short, expect a boost 🔥.
+**NOTE**🔥: We suggest that if you're just starting with DeepLabCut you start with the
+PyTorch backend. You will easily know which "engine" you are using by looking at the
+main `config.yaml` file, or top right corner in the GUI. If you have DeepLabCut projects
+in TensorFlow, we've got you covered too: you can seamlessly switch to train your
+already labeled data by simply switching the engine (and thereby also compare
+performance). In short, expect a boost 🔥.
-In short, PyTorch models can be trained in any DeepLabCut project. If you have a project already made, simply add a new key to your
-project `config.yaml` file specifying `engine: pytorch`. Then any new training dataset
-that will be created will be a PyTorch model (see
-[Creating Shuffles and Model Configuration](#Creating-Shuffles-and-Model-Configuration))
-to learn more about training PyTorch models. To train Tensorflow models again, you can
-set `engine: tensorflow`.
+In short, PyTorch models can be trained in any DeepLabCut project. If you have a project
+already made, simply add a new key to your project `config.yaml` file specifying
+`engine: pytorch`. Then any new training dataset that will be created will be a PyTorch
+model (see [Creating Shuffles and Model Configuration](
+#Creating-Shuffles-and-Model-Configuration)) to learn more about training PyTorch
+models. To train Tensorflow models again, you can set `engine: tensorflow`.
### Installation
-During the alpha phase, you can use the `yaml` we provide, or create a new `env`.
-- If you are a beginner user, [please see these docs!](https://deeplabcut.github.io/DeepLabCut/docs/beginner-guides/beginners-guide.html)
-- If you are an advanced user, here is a quick start. [“Install PyTorch”](https://pytorch.org/get-started/locally/), then:
-```
-conda install -c conda-forge pytables==3.8.0
-pip install git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut[gui,modelzoo,wandb]
-```
+To see the DeepLabCut 3.0 installation guide, check the [installation docs](how-to-install).
### Using the GUI
-You can use the GUI to train DeepLabCut projects. However, you cannot switch between
-PyTorch and Tensorflow models while using the GUI. If you have set your engine to
-`pytorch`, then the GUI will only offer the creation of PyTorch shuffles.
-
-You can create `tensorflow` shuffles and train them again by setting the
-`engine: tensorflow` in the top right corner of the GUI.
+You can use the GUI to train DeepLabCut projects. You can switch between the PyTorch
+and TensorFlow engine through the drop-down menu in the top right corner.
## Major changes
@@ -76,17 +68,17 @@ parameters are not valid for the DLC 3.0 PyTorch API.
| API Method | Implemented | Parameters not yet implemented | Parameters invalid for pytorch |
|--------------------------------|:-----------:|-----------------------------------------------------------------------------------------------------|-----------------------------------------------------|
-| `train_network` | 🟢 | `keepdeconvweights` | `maxiters`, `saveiters`, `allow_growth`, `autotune` |
+| `train_network` | 🟢 | | `maxiters`, `saveiters`, `allow_growth`, `autotune` |
| `return_train_network_path` | 🟢 | | |
-| `evaluate_network` | 🟢 | `comparisonbodyparts`, `rescale`, `per_keypoint_evaluation` | |
+| `evaluate_network` | 🟢 | | |
| `return_evaluate_network_data` | 🔴 | | `TFGPUinference`, `allow_growth` |
-| `analyze_videos` | 🟢 | `in_random_order`, `dynamic`, `n_tracks`, `calibrate` | |
-| `create_tracking_dataset` | 🔴 | | |
-| `analyze_time_lapse_frames` | 🟠 | the name has changed to `analyze_images` to better reflect what it actually does (no video needed) | |
-| `convert_detections2tracklets` | 🟢 | `greedy`, `calibrate`, `window_size` | |
+| `analyze_videos` | 🟠 | `greedy`, `calibrate`, `window_size` | |
+| `create_tracking_dataset` | 🟢 | | |
+| `analyze_time_lapse_frames` | 🟢 | the name has changed to `analyze_images` to better reflect what it actually does (no video needed) | |
+| `convert_detections2tracklets` | 🟠 | `greedy`, `calibrate`, `window_size` | |
| `extract_maps` | 🟢 | | |
| `visualize_scoremaps` | 🟢 | | |
| `visualize_locrefs` | 🟢 | | |
| `visualize_paf` | 🟢 | | |
| `extract_save_all_maps` | 🟢 | | |
-| `export_model` | 🟢 | | |
\ No newline at end of file
+| `export_model` | 🟢 | | |
diff --git a/docs/pytorch_dlc.md b/docs/pytorch_dlc.md
new file mode 100644
index 0000000000..157a1c19af
--- /dev/null
+++ b/docs/pytorch_dlc.md
@@ -0,0 +1,167 @@
+# DeepLabCut: PyTorch API
+
+## Modules
+
+- [data](https://github.com/nastya236/DLCdev/blob/69005057eeac3c1492712863303f8268cee776e6/deeplabcut/pose_estimation_pytorch/data/project.py#L7):
+The `deeplabcut.pose_estimations_pytorch.data` package contains all code for pytorch
+dataset creation and test/train splitting.
+ - `Project` class provides train and test splitting and converts dataset to required
+ format. For instance, to [COCO]() format.
+ - `PoseTrainDataset` class is a [torch.utils.Dataset](https://pytorch.org/docs/stable/data.html) class, which converts raw
+ images and keypoints to a tensor dataset for training and evaluation.
+- [models](https://github.com/nastya236/DLCdev/blob/69005057eeac3c1492712863303f8268cee776e6/deeplabcut/pose_estimation_pytorch/data/models):
+The `deeplabcut.pose_estimations_pytorch.models` package contains all related to
+building a model with `backbone`, `neck` (optional) and `head`.
+- [train_module](https://github.com/nastya236/DLCdev/blob/69005057eeac3c1492712863303f8268cee776e6/deeplabcut/pose_estimation_pytorch/data/models):
+The `deeplabcut.pose_estimations_pytorch.train_module` contains all classes for model
+training and validation.
+
+## API
+
+The PyTorch implementation of DeepLabCut is very similar to the Tensorflow multi-animal
+implementation: the same steps need to be followed, just with slightly different API
+calls (and different model names).
+
+Up until it's time to create the training dataset, there are no changes to the way a
+PyTorch or Tensorflow project should be created.
+
+### Creating a Training Dataset
+
+To create a training dataset for a DeepLabCut PyTorch model, simply call:
+```python
+import deeplabcut
+deeplabcut.create_training_dataset(
+ path_config_file,
+ net_type="dekr_32",
+)
+```
+
+This will create folders for the training dataset in the same way as the Tensorflow
+version, with an addition configuration file in the `train` folder:
+`pytorch_config.yaml`. This is the file that can be edited to modify the model
+architecture or training parameters.
+
+There are currently two "families" of models implemented in PyTorch: DEKR (Geng, Zigang,
+et al. "Bottom-up human pose estimation via disentangled keypoint regression."
+Proceedings of the IEEE/CVF conference on computer vision and pattern recognition.
+2021.) and Tokenpose (Li, Yanjie, et al. "Tokenpose: Learning keypoint tokens for human
+pose estimation." Proceedings of the IEEE/CVF International conference on computer
+vision. 2021.). The choices of `net_type` that will create PyTorch training sets are:
+- `"dekr_16"`
+- `"dekr_32"`
+- `"dekr_48"`
+- `"token_pose_w16"`
+- `"token_pose_w32"`
+- `"token_pose_w48"`
+
+Note that Tokenpose models cannot currently be used with projects that contain unique
+keypoints.
+
+### Training the network
+Training a PyTorch model is done in a very similar manner as a tensorflow model, though
+currently the PyTorch API needs to be called directly:
+```python
+import deeplabcut.pose_estimation_pytorch.apis as api
+api.train_network(config_path, shuffle=1, trainingsetindex=0)
+```
+
+**Parameters**
+```
+config : path to the yaml config file of the project
+shuffle : index of the shuffle we want to train on
+trainingsetindex : training set index
+transform: Augmentation pipeline for the images
+ if None, the augmentation pipeline is built from config files
+ Advice if you want to use custom transformations:
+ Keep in mind that in order for transfer learning to be efficient, your
+ data statistical distribution should resemble the one used to pretrain your backbone
+ In most cases (e.g backbone was pretrained on ImageNet), that means it should be Normalized with
+ A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
+transform_cropped: Augmentation pipeline for the cropped images around animals
+ if None, the augmentation pipeline is built from config files
+ Advice if you want to use custom transformations:
+ Keep in mind that in order for transfer learning to be efficient, your
+ data statistical distribution should resemble the one used to pretrain your backbone
+ In most cases (e.g backbone was pretrained on ImageNet), that means it should be Normalized with
+ A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
+modelprefix: directory containing the deeplabcut configuration files to use
+ to train the network (and where snapshots will be saved). By default, they
+ are assumed to exist in the project folder.
+snapshot_path: if resuming training, used to specify the snapshot from which to resume
+detector_path: if resuming training of a top down model, used to specify the detector snapshot from
+ which to resume
+**kwargs : could be any entry of the pytorch_config dictionary. Examples are
+ to see the full list see the pytorch_cfg.yaml file in your project folder
+```
+
+### Evaluating the network
+As for training, the main difference is the need to call the API directly.
+```python
+import deeplabcut.pose_estimation_pytorch.apis as api
+api.evaluate_network(config_path, shuffle=1, trainingsetindex="all")
+```
+
+**Parameters**
+```
+config: path to the project's config file
+shuffles: Iterable of integers specifying the shuffle indices to evaluate.
+trainingsetindex: Integer specifying which training set fraction to use.
+ Evaluates all fractions if set to "all"
+snapshotindex: index (starting at 0) of the snapshot we want to load. To
+ evaluate the last one, use -1. To evaluate all snapshots, use "all". For
+ example if we have 3 models saved
+ - snapshot-0.pt
+ - snapshot-50.pt
+ - snapshot-100.pt
+ and we want to evaluate snapshot-50.pt, snapshotindex should be 1. If None,
+ the snapshotindex is loaded from the project configuration.
+plotting: Plots the predictions on the train and test images. If provided it must
+ be either ``True``, ``False``, ``"bodypart"``, or ``"individual"``. Setting
+ to ``True`` defaults as ``"bodypart"`` for multi-animal projects.
+show_errors: display train and test errors.
+transform: transformation pipeline for evaluation
+ ** Should normalise the data the same way it was normalised during training **
+modelprefix: directory containing the deeplabcut models to use when evaluating
+ the network. By default, they are assumed to exist in the project folder.
+batch_size: the batch size to use for evaluation
+```
+
+### Analyzing novel videos
+One big difference between the PyTorch and Tensorflow implementations comes in the way
+animal assembly happens (for multi-animal models). While in Tensorflow, assembly was a
+separate step that needed to be done from the keypoints, in the PyTorch version it's
+integrated directly into the models. From an API standpoint, that does not change much.
+
+Again, the PyTorch API needs to be invoked directly (it also has the `auto_track`
+option).
+```python
+import deeplabcut.pose_estimation_pytorch.apis as api
+api.analyze_videos(config_path, ["/fullpath/project/videos/test.mp4"], videotype=".mp4")
+```
+
+The PyTorch detections need to be converted to tracklets using the PyTorch API, but then
+the original tracklet stitching can be used.
+```python
+import deeplabcut
+import deeplabcut.pose_estimation_pytorch.apis as api
+api.convert_detections2tracklets(
+ config_path,
+ videos=['/fullpath/project/videos/test.mp4'],
+ videotype=".mp4",
+)
+deeplabcut.stitch_tracklets(
+ config_path,
+ videos=['/fullpath/project/videos/test.mp4'],
+ videotype=".mp4",
+)
+```
+
+Creating labeled videos can then be called in exactly the same way as before.
+```python
+import deeplabcut
+deeplabcut.create_labeled_video(
+ config_path,
+ videos=['/fullpath/project/videos/test.mp4'],
+ videotype=".mp4",
+)
+```
diff --git a/docs/quick-start/single_animal_quick_guide.md b/docs/quick-start/single_animal_quick_guide.md
index c34268f2dc..6e4090d854 100644
--- a/docs/quick-start/single_animal_quick_guide.md
+++ b/docs/quick-start/single_animal_quick_guide.md
@@ -4,17 +4,21 @@
Open ipython in the terminal:
```
ipython
+```
+
+Import DeepLabCut:
+```
import deeplabcut
```
Create a new project:
```
-deeplabcut.create_new_project(‘project_name’,‘experimenter’,[‘path of video 1’,‘path of video2’,..])
+deeplabcut.create_new_project("project_name", "experimenter", ["path of video 1", "path of video2", ..])
```
Set a config_path variable for ease of use + go edit this file!:
```
-config_path = ‘yourdirectory/project_name/config.yaml’
+config_path = "yourdirectory/project_name/config.yaml"
```
Extract frames:
@@ -49,20 +53,20 @@ deeplabcut.evaluate_network(config_path)
Video analysis:
```
-deeplabcut.analyze_videos(config_path, [‘path of video 1’,‘path of video2’, ...])
+deeplabcut.analyze_videos(config_path, ["path of video 1", "path of video2", ..])
```
Filter predictions [OPTIONAL]:
```
-deeplabcut.filterpredictions(config_path, [‘path of video 1’,‘path of video2’,..])
+deeplabcut.filterpredictions(config_path, ["path of video 1", "path of video2", ..])
```
Plot results (trajectories):
```
-deeplabcut.plot_trajectories(config_path, [‘path of video 1’,‘path of video2’,..], filtered=True)
+deeplabcut.plot_trajectories(config_path, ["path of video 1", "path of video2", ..], filtered=True)
```
Create a video:
```
-deeplabcut.create_labeled_video(config_path, [‘path of video 1’,‘path of video2’,..]filtered=True)
+deeplabcut.create_labeled_video(config_path, ["path of video 1", "path of video2", ..], filtered=True)
```
diff --git a/docs/quick-start/tutorial_maDLC.md b/docs/quick-start/tutorial_maDLC.md
index 947e036513..507ad559d8 100644
--- a/docs/quick-start/tutorial_maDLC.md
+++ b/docs/quick-start/tutorial_maDLC.md
@@ -69,7 +69,17 @@ deeplabcut.create_multianimaltraining_dataset(
```
**(7) Train the network**
+
```python
+# PyTorch Engine
+deeplabcut.train_network(
+ config_path,
+ device="cuda",
+ save_epochs=5,
+ epochs=200,
+)
+
+# TensorFlow Engine
deeplabcut.train_network(
config_path,
saveiters=10000,
diff --git a/docs/recipes/BatchProcessing.md b/docs/recipes/BatchProcessing.md
index abac0bf27d..b0f91bbee6 100644
--- a/docs/recipes/BatchProcessing.md
+++ b/docs/recipes/BatchProcessing.md
@@ -3,10 +3,12 @@
## Tips for working with DLC networks:
-- Now you have a DLC network and are happy with the performance on selected videos, you may want to run it on all your videos without hassle. If all your videos are in one folder this is easy, simply pass the foldername to `deeplabcut.analyze_videos(config,[folder])` and you are fine. What if the videos are scattered?
+Now you have a DLC network and are happy with the performance on selected videos, you may want to run it on all your
+videos without hassle. If all your videos are in one folder this is easy, simply pass the foldername to
+`deeplabcut.analyze_videos(config,[folder])` and you are fine. What if the videos are scattered?
-
-You could create a simply script that runs over all your video folders with the network of choice. Your "key" to this network is your config.yaml file.
+You can create a simple script that runs over all your video folders with the network of choice. Your "key" to this
+network is your config.yaml file.

@@ -33,18 +35,18 @@ import deeplabcut
def getsubfolders(folder):
''' returns list of subfolders '''
- return [os.path.join(folder,p) for p in os.listdir(folder) if os.path.isdir(os.path.join(folder,p))]
+ return [os.path.join(folder, p) for p in os.listdir(folder) if os.path.isdir(os.path.join(folder, p))]
-project='ComplexWheelD3-12-Fumi-2019-01-28'
+project = "ComplexWheelD3-12-Fumi-2019-01-28"
-shuffle=1
+shuffle = 1
-prefix='/home/alex/DLC-workshopRowland'
+prefix = "/home/alex/DLC-workshopRowland"
-projectpath=os.path.join(prefix,project)
-config=os.path.join(projectpath,'config.yaml')
+projectpath = os.path.join(prefix, project)
+config = os.path.join(projectpath, "config.yaml")
-basepath='/home/alex/BenchmarkingExperimentsJan2019' #data'
+basepath = "/home/alex/BenchmarkingExperimentsJan2019"
'''
@@ -57,13 +59,13 @@ Imagine that the data (here: videos of 3 different types) are in subfolders:
'''
-subfolders=getsubfolders(basepath)
+subfolders = getsubfolders(basepath)
for subfolder in subfolders: #this would be January, February etc. in the upper example
- print("Starting analyze data in:", subfolder)
- subsubfolders=getsubfolders(subfolder)
+ print("Starting analyze data in: ", subfolder)
+ subsubfolders = getsubfolders(subfolder)
for subsubfolder in subsubfolders: #this would be Febuary1, etc. in the upper example...
- print("Starting analyze data in:", subsubfolder)
- for vtype in ['.mp4','.m4v','.mpg']:
+ print("Starting analyze data in: ", subsubfolder)
+ for vtype in [".mp4", ".m4v", ".mpg"]:
deeplabcut.analyze_videos(config,[subsubfolder],shuffle=shuffle,videotype=vtype,save_as_csv=True)
```
@@ -90,25 +92,25 @@ import os
import deeplabcut
-Maxiter=int(1.5*10**5)
+epochs = 200
model=int(sys.argv[1])
-Projects=[['project1-phoenix-2019-01-28'],['ComplexWheelD3-12-Fumi-2019-01-28', 'maze-ariel-2019-01-28'], ['TBI-BvA-2019-01-28','group-eli-2019-01-28']]
+Projects=[["project1-phoenix-2019-01-28"], ["ComplexWheelD3-12-Fumi-2019-01-28", "maze-ariel-2019-01-28"], ["TBI-BvA-2019-01-28", "group-eli-2019-01-28"]]
shuffle=1
-prefix='/home/alex/DLC-workshopRowland'
+prefix = "/home/alex/DLC-workshopRowland"
for project in Projects[model]:
- projectpath=os.path.join(prefix,project)
- config=os.path.join(projectpath,'config.yaml')
+ projectpath = os.path.join(prefix, project)
+ config = os.path.join(projectpath, "config.yaml")
- cfg=deeplabcut.auxiliaryfunctions.read_config(config)
- previous_path=cfg['project_path']
+ cfg = deeplabcut.auxiliaryfunctions.read_config(config)
+ previous_path = cfg["project_path"]
- cfg['project_path']=projectpath
- deeplabcut.auxiliaryfunctions.write_config(config,cfg)
+ cfg["project_path"]=projectpath
+ deeplabcut.auxiliaryfunctions.write_config(config, cfg)
print("This is the name of the script: ", sys.argv[0])
print("Shuffle: ", shuffle)
@@ -116,22 +118,18 @@ for project in Projects[model]:
deeplabcut.create_training_dataset(config, Shuffles=[shuffle])
- deeplabcut.train_network(config, shuffle=shuffle, max_snapshots_to_keep=5, maxiters=Maxiter)
+ deeplabcut.train_network(config, shuffle=shuffle, max_snapshots_to_keep=5, epochs=epochs)
print("Evaluating...")
- deeplabcut.evaluate_network(config, Shuffles=[shuffle],plotting=True)
+ deeplabcut.evaluate_network(config, Shuffles=[shuffle], plotting=True)
print("Analyzing videos..., switching to last snapshot...")
- #cfg=deeplabcut.auxiliaryfunctions.read_config(config)
- #cfg['snapshotindex']=-1
- #deeplabcut.auxiliaryfunctions.write_config(config,cfg)
-
for vtype in ['.mp4','.m4v','.mpg']:
try:
- deeplabcut.analyze_videos(config,[str(os.path.join(projectpath,'videos'))],shuffle=shuffle,videotype=vtype,save_as_csv=True)
+ deeplabcut.analyze_videos(config, [str(os.path.join(projectpath, "videos"))], shuffle=shuffle, videotype=vtype, save_as_csv=True)
except:
pass
print("DONE WITH ", project," resetting to original path")
- cfg['project_path']=previous_path
- deeplabcut.auxiliaryfunctions.write_config(config,cfg)
+ cfg["project_path"] = previous_path
+ deeplabcut.auxiliaryfunctions.write_config(config, cfg)
```
diff --git a/docs/recipes/ClusteringNapari.md b/docs/recipes/ClusteringNapari.md
index 4d309d74e5..00a3b9f514 100644
--- a/docs/recipes/ClusteringNapari.md
+++ b/docs/recipes/ClusteringNapari.md
@@ -1,62 +1,88 @@
# Clustering in the napari-DeepLabCut GUI
-To increase model performance, one can find the errors in the user-defined label (or in output H5 files after video inference). You can correct the errors and add them back into the training dataset, a process called active learning.
+To increase model performance, one can find the errors in the user-defined label (or in output H5 files after video
+inference). You can correct the errors and add them back into the training dataset, a process called active learning.
-User errors can be detrimental to model performance, so beyond just `check_labels`, this tool allows you to find your mistakes. If you are curious about how errors affect performance, read the paper: [A Primer on Motion Capture with Deep Learning: Principles, Pitfalls, and Perspectives](https://www.sciencedirect.com/science/article/pii/S0896627320307170).
+User errors can be detrimental to model performance, so beyond just `check_labels`, this tool allows you to find your
+mistakes. If you are curious about how errors affect performance, read the paper:
+[A Primer on Motion Capture with Deep Learning: Principles, Pitfalls, and Perspectives](https://www.sciencedirect.com/science/article/pii/S0896627320307170).
**TL;DR: your data quality matters!**
-
+
```{Hint}
**Labeling Pitfalls: How Corruptions Affect Performance**
-(A) Illustration of two types of labeling errors. Top is ground truth, middle is missing a label at the tailbase, and bottom is if the labeler swapped the ear identity (left to right, etc.). (B) Using a small training dataset of 106 frames, how do the corruptions in (A) affect the percent of correct keypoints (PCK) on the test set as the distance to ground truth increases from 0 pixels (perfect prediction) to 20 pixels (larger error)? The x axis denotes the difference in the ground truth to the predicted location (RMSE in pixels), whereas the y axis is the fraction of frames considered accurate (e.g., z80% of frames fall within 9 pixels, even on this small training dataset, for points that are not corrupted, whereas for swapped points this falls to z65%). The fraction of the dataset that is corrupted affects this value. Shown is when missing the tailbase label (top) or swapping the ears in 1%, 5%, 10%, and 20% of frames (of 106 labeled training images). Swapping versus missing labels has a more notable adverse effect on network performance.
+(A) Illustration of two types of labeling errors. Top is ground truth, middle is missing a label at the tailbase, and
+bottom is if the labeler swapped the ear identity (left to right, etc.). (B) Using a small training dataset of 106
+frames, how do the corruptions in (A) affect the percent of correct keypoints (PCK) on the test set as the distance
+to ground truth increases from 0 pixels (perfect prediction) to 20 pixels (larger error)? The x axis denotes the
+difference in the ground truth to the predicted location (RMSE in pixels), whereas the y axis is the fraction of
+frames considered accurate (e.g., z80% of frames fall within 9 pixels, even on this small training dataset, for
+points that are not corrupted, whereas for swapped points this falls to z65%). The fraction of the dataset that is
+corrupted affects this value. Shown is when missing the tailbase label (top) or swapping the ears in 1%, 5%, 10%,
+and 20% of frames (of 106 labeled training images). Swapping versus missing labels has a more notable adverse effect
+on network performance.
```
-The DeepLabCut toolbox supports **active learning** by extracting outlier frames be several methods and allowing the user to correct the frames, then retrain the model. See the [Nature Protocols paper](https://www.nature.com/articles/s41596-019-0176-0) for the detailed steps, or in the docs, [here](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#m-optional-active-learning-network-refinement-extract-outlier-frames).
+The DeepLabCut toolbox supports **active learning** by extracting outlier frames be several methods and allowing the
+user to correct the frames, then retrain the model. See the
+[Nature Protocols paper](https://www.nature.com/articles/s41596-019-0176-0) for the detailed steps, or in the docs,
+[here](active-learning).
-To facilitate this process, here we propose a new way to detect 'outlier frames', which is planned to be released in ~Sept 2022. Your contributions and suggestions are welcomed, so test the [PR](https://github.com/DeepLabCut/napari-deeplabcut/pull/38) and give us feedback!
+To facilitate this process, here we propose a new way to detect 'outlier frames'.
+Your contributions and suggestions are welcomed, so test the
+[PR](https://github.com/DeepLabCut/napari-deeplabcut/pull/38) and give us feedback!
-This #cookbook recipe aims to show a usecase of **clustering in napari** and is contributed by 2022 DLC AI Resident [Sabrina Benas](https://twitter.com/Sabrineiitor) 💜.
+This #cookbook recipe aims to show a usecase of **clustering in napari** and is contributed by 2022 DLC AI Resident
+[Sabrina Benas](https://twitter.com/Sabrineiitor) 💜.
## Detect Outliers to Refine Labels
### Open `napari` and the `DeepLabCut plugin`
- - Then open your `CollectedData_.h5` file. We used the Horse-30 dataset, presented in [Mathis, Biasi et al. WACV 2022](http://horse10.deeplabcut.org/), as our demo and development set. Here is an example of what it should look like:
+Then open your `CollectedData_.h5` file. We used the Horse-30 dataset, presented in
+[Mathis, Biasi et al. WACV 2022](http://horse10.deeplabcut.org/), as our demo and development set. Here is an example of what it should look like:
-
+
### Clustering
-- Click on the button `cluster` and wait a few seconds until it displays a new layer with the cluster:
+Click on the button `cluster` and wait a few seconds until it displays a new layer with the cluster:
-
+
You can click on a point and see the image on the right with the keypoints:
-
+
### Visualize & refine
-If you decided to refine that frame (we moved the points to make outliers obvious), click `show img` and refine them using the plugin features and instructions:
+If you decided to refine that frame (we moved the points to make outliers obvious), click `show img` and refine them
+using the plugin features and instructions:
-
+
- ```{Attention}
- When you're done, you need to click `ctl-s` to save it.
+```{Attention}
+When you're done, you need to click `ctl-s` to save it.
```
-- You can go back to the cluster layer by clicking on `close img` and refine another image. Reminder, when you're done editing you need to click `ctl-s` to save your work. And now you can take the updated `CollectedData` file, create and **new training shuffle**, and train the network! Read more about how to [create a training dataset](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#f-create-training-dataset-s).
+You can go back to the cluster layer by clicking on `close img` and refine another image. Reminder, when you're done
+editing you need to click `ctl-s` to save your work. And now you can take the updated `CollectedData` file, create
+and **new training shuffle**, and train the network! Read more about how to
+[create a training dataset](create-training-dataset).
```{hint}
-If you want to change the clustering method, you can modify the file [kmeans.py](https://github.com/DeepLabCutAIResidency/napari-deeplabcut/blob/cluster1/src/napari_deeplabcut/kmeans.py)
+If you want to change the clustering method, you can modify the file
+[kmeans.py](https://github.com/DeepLabCutAIResidency/napari-deeplabcut/blob/cluster1/src/napari_deeplabcut/kmeans.py)
+```
::::{important}
-You have to keep the way the file is opened (pandas dataframe) and the output has to be the cluster points, the points colors in the cluster colors and the frame names (in this order).
+You have to keep the way the file is opened (pandas dataframe) and the output has to be the cluster points, the points
+colors in the cluster colors and the frame names (in this order).
::::
```
diff --git a/docs/recipes/DLCMethods.md b/docs/recipes/DLCMethods.md
index b2f7a5a7af..74d5ec4c65 100644
--- a/docs/recipes/DLCMethods.md
+++ b/docs/recipes/DLCMethods.md
@@ -2,7 +2,13 @@
**Pose estimation using DeepLabCut**
-For body part tracking we used DeepLabCut (version 2.X.X) [Mathis et al, 2018, Nath et al, 2019]. Specifically, we labeled X number of frames taken from X videos/animals (then X% was used for training (default is 95%). We used a X-based neural network (i.e., X = ResNet-50, ResNet-101, MobileNetV2-0.35, MobileNetV2-0.5, MobileNetV2-0.75, MobileNetV2-1, EfficientNet ..X, dlcrnet_ms5, etc.)*** with default parameters* for X number of training iterations. We validated with X number of shuffles, and found the test error was: X pixels, train: X pixels (image size was X by X). We then used a p-cutoff of X (i.e. 0.9) to condition the X,Y coordinates for future analysis. This network was then used to analyze videos from similar experimental settings.
+For body part tracking we used DeepLabCut (version 3.X.X) [Mathis et al, 2018, Nath et al, 2019]. Specifically, we
+labeled X number of frames taken from X videos/animals (then X% was used for training (default is 95%). We used a
+X-based neural network (i.e., X = ResNet-50, ResNet-101, MobileNetV2-0.35, MobileNetV2-0.5, MobileNetV2-0.75,
+MobileNetV2-1, EfficientNet ..X, dlcrnet_ms5, cspnext_s, dekr_w32, rtmpose_s, etc.)*** with default parameters* for X
+number of training iterations. We validated with X number of shuffles, and found the test error was: X pixels, train:
+X pixels (image size was X by X). We then used a p-cutoff of X (i.e. 0.9) to condition the X,Y coordinates for future
+analysis. This network was then used to analyze videos from similar experimental settings.
*If any defaults were changed in *`pose_config.yaml`*, mention them.
@@ -43,4 +49,5 @@ If you use ResNets, consider citing Insafutdinov et al 2016 & He et al 2016. If
> 770–778 (2016). URL https://arxiv.org/abs/
> 1512.03385.
-We also have the network graphic freely available on SciDraw.io if you'd like to use it! https://scidraw.io/drawing/290. If you use our DLC logo, please include the TM symbol, thank you!
+We also have the network graphic freely available on SciDraw.io if you'd like to use it! https://scidraw.io/drawing/290.
+If you use our DLC logo, please include the TM symbol, thank you!
diff --git a/docs/recipes/MegaDetectorDLCLive.md b/docs/recipes/MegaDetectorDLCLive.md
index 20e35a38fc..8af6f59d1b 100644
--- a/docs/recipes/MegaDetectorDLCLive.md
+++ b/docs/recipes/MegaDetectorDLCLive.md
@@ -14,7 +14,9 @@ MegaDetector detects an animal and generates a bounding box around the animal. T
## DeepLabCut-Live
-DeepLabCut-Live! is a real-time package for running DeepLabCut. However, you can also use it as a lighter-weight package for running DeeplabCut even if you don't need real-time. It's very useful to use in HPC or servers, or in Apps, as we do here. To read more, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/deeplabcutlive.html).
+DeepLabCut-Live! is a real-time package for running DeepLabCut. However, you can also use it as a lighter-weight
+package for running DeeplabCut even if you don't need real-time. It's very useful to use in HPC or servers, or in Apps,
+as we do here. To read more, check out the [docs](deeplabcut-live).
### MegaDetector meets DeepLabCut
diff --git a/docs/recipes/OpenVINO.md b/docs/recipes/OpenVINO.md
index 2033045ce8..78ea18e82f 100644
--- a/docs/recipes/OpenVINO.md
+++ b/docs/recipes/OpenVINO.md
@@ -1,10 +1,15 @@
# Intel OpenVINO backend
+::::{warning}
+This feature is currently implemented for TensorFlow-based models only.
+::::
+
DeepLabCut provides an option to run deep learning model with [OpenVINO](https://github.com/openvinotoolkit/openvino) backend.
-To enable OpenVINO in your pipeline, use `use_openvino` flag of `analyze_videos` method with one of string values indicating device:
-* "CPU" - Use CPU. This is a default value.
-* "GPU" - Use iGPU (requires OpenCL to be installed). First launch might take some time for kernels initialization.
-* "MULTI:CPU,GPU" - Use CPU and GPU simultaneously. In most cases this option provides the best efficiency.
+To enable OpenVINO in your pipeline, use `use_openvino` flag of `analyze_videos` method with one of string values
+indicating device:
+* ```"CPU"``` - Use CPU. This is a default value.
+* ```"GPU"``` - Use GPU (requires OpenCL to be installed). First launch might take some time for kernels initialization.
+* ```"MULTI:CPU,GPU"``` - Use CPU and GPU simultaneously. In most cases this option provides the best efficiency.
```python
def analyze_videos(
diff --git a/docs/recipes/OtherData.md b/docs/recipes/OtherData.md
index 1d9e648d70..73343284b3 100644
--- a/docs/recipes/OtherData.md
+++ b/docs/recipes/OtherData.md
@@ -9,12 +9,12 @@ Some users may have annotation data in different formats, yet want to use the DL
Here is a guide to do this via the ".csv" route: (the pandas array route is identical, just format the pandas array in the same way).
-**Step 1**: create a project as describe in the user guide: https://github.com/AlexEMG/DeepLabCut/blob/master/docs/UseOverviewGuide.md#create-a-new-project
+**Step 1**: create a project as describe in the user guide: https://github.com/DeepLabCut/DeepLabCut/blob/main/docs/UseOverviewGuide.md#create-a-new-project
**Step 2**: edit the ``config.yaml`` file to include the body part names, please take care that spelling, spacing, and capitalization are IDENTICAL to the "labeled data body part names".
-**Step 3**: Please inspect the excel formatted sheet (.csv) from our [demo project](https://github.com/AlexEMG/DeepLabCut/tree/master/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1)
-- i.e. this file: https://github.com/AlexEMG/DeepLabCut/blob/master/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1/CollectedData_Mackenzie.csv
+**Step 3**: Please inspect the excel formatted sheet (.csv) from our [demo project](https://github.com/DeepLabCut/DeepLabCut/tree/main/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1)
+- i.e. this file: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/Reaching-Mackenzie-2018-08-30/labeled-data/reachingvideo1/CollectedData_Mackenzie.csv
**Step 4**: Edit the .csv file such that it contains the X, Y pixel coordinates, the body part names, the scorer name as well as the relative path to the image: e.g. /labeled-data/somefolder/img017.jpg
Then make sure the scorer name, and body parts are the same in the config.yaml file.
diff --git a/docs/recipes/TechHardware.md b/docs/recipes/TechHardware.md
index f4610f5442..879fc67dfb 100644
--- a/docs/recipes/TechHardware.md
+++ b/docs/recipes/TechHardware.md
@@ -23,10 +23,20 @@ The software is very robust to track data from any camera (cell phone cameras, g
**Anaconda/Python3:** Anaconda: a free and open source distribution of the Python programming language (download from https://www.anaconda.com/). DeepLabCut is written in Python 3 (https://www.python.org/) and not compatible with Python 2.
-
-**TensorFlow** You will need [TensorFlow](https://www.tensorflow.org/) (we used version 1.0 in the paper, later versions also work with the provided code (we tested **TensorFlow versions 1.0 to 1.15, and 2.0 to 2.5**; we recommend TF2.5 now) for Python 3.7, 3.8, or 3.9 with GPU support.
-
-To note, is it possible to run DeepLabCut on your CPU, but it will be VERY slow (see: [Mathis & Warren](https://www.biorxiv.org/content/early/2018/10/30/457242)). However, this is the preferred path if you want to test DeepLabCut on your own computer/data before purchasing a GPU, with the added benefit of a straightforward installation! Otherwise, use our COLAB notebooks for GPU access for testing.
-
-Docker: We highly recommend advaced users use the supplied [Docker container](https://github.com/MMathisLab/Docker4DeepLabCut2.0).
-NOTE: [this container does not work on windows hosts!](https://github.com/NVIDIA/nvidia-docker/issues/43)
+**For the TensorFlow Engine:** You will need [TensorFlow](https://www.tensorflow.org/).
+We used version 1.0 in the paper, later versions also work with the provided code (we
+tested **TensorFlow versions 1.0 to 1.15, and 2.0 to 2.12 (2.10 for Windows)**; we
+recommend TF2.12 for MacOS/Ubuntu and 2.10 for Windows) for Python 3.10 with GPU
+support.
+
+To note, is it possible to run DeepLabCut on your CPU, but it will be VERY slow (see:
+[Mathis & Warren](https://www.biorxiv.org/content/early/2018/10/30/457242)). However, this is the preferred path if you want to test
+DeepLabCut on your own computer/data before purchasing a GPU, with the added benefit of
+a straightforward installation! Otherwise, use our COLAB notebooks for GPU access for
+testing.
+
+Docker: We highly recommend advanced users use the supplied [Docker container](
+docker-containers).
+
+NOTE: [Currently GPU support in Docker Desktop is only available on Windows with the
+WSL2 backend.](https://docs.docker.com/desktop/features/gpu/)
diff --git a/docs/recipes/UsingModelZooPupil.md b/docs/recipes/UsingModelZooPupil.md
index 2aac4a4d00..d73a1acbd7 100644
--- a/docs/recipes/UsingModelZooPupil.md
+++ b/docs/recipes/UsingModelZooPupil.md
@@ -1,17 +1,25 @@
# Using ModelZoo models on your own datasets
-Animal behavior has to be analyzed with painstaking accuracy. Therefore, animal pose estimation has been an important tool to study animal behavior precisely.
+
Animal behavior has to be analyzed with painstaking accuracy. Therefore, animal pose estimation has been
+an important tool to study animal behavior precisely.
-Beside providing an open source toolbox for researchers to develop customized deep neural networks for markerless pose estimation, we at DeepLabCut also aim to build robust, generalizable models. Part of this effort is via the [DeeplabCut ModelZoo](http://www.mackenziemathislab.org/dlc-modelzoo).
+Beside providing an open source toolbox for researchers to develop customized deep neural networks for markerless pose
+estimation, we at DeepLabCut also aim to build robust, generalizable models. Part of this effort is via the
+[DeeplabCut ModelZoo](http://modelzoo.deeplabcut.org/).
-The Zoo hosts user-contributed and #teamDLC developed models that are trained on specific animals and scenarios. You can analyze your videos directly with these models without training. The models have strong zero-shot performance on unseen out-of-domain data which can be further improved via pseudo-labeling. Please check the first [ModelZoo manuscript](https://arxiv.org/abs/2203.07436v1) for further details.
+The Zoo hosts user-contributed and DLC-team developed models that are trained on specific animals and scenarios. You can
+analyze your videos directly with these models without training. The models have strong zero-shot performance on unseen
+out-of-domain data which can be further improved via pseudo-labeling. Please check the first
+[ModelZoo manuscript](https://arxiv.org/abs/2203.07436v1) for further details.
-This recipe aims to show a usecase of the **mouse_pupil_vclose** and is contributed by 2022 DLC AI Resident [Neslihan Wittek](https://github.com/neslihanedes) 💜.
+This recipe aims to show a usecase of the **mouse_pupil_vclose** and is contributed by 2022 DLC AI Resident
+[Neslihan Wittek](https://github.com/neslihanedes) 💜.
## `mouse_pupil_vclose` model
This model was contributed by Jim McBurney-Lin at University of California Riverside, USA.
-The model was trained on images of C57/B6J mice eyes, and also then augmented with mouse eye data from the Mathis Lab at EPFL.
+The model was trained on images of C57/B6J mice eyes, and also then augmented with mouse eye data from the Mathis Lab at
+EPFL.
@@ -30,23 +38,37 @@ The model was trained on images of C57/B6J mice eyes, and also then augmented wi
| 8 | VLpupil | Ventral/left aspect of pupil |
-Since we would like to evaluate the models performance on out-of-domain data, we will analyze pigeon pupils. For more discussions and work on so-called out-of-domain data, see [Mathis, Biasi 2020](http://www.mackenziemathislab.org/horse10).
+Since we would like to evaluate the models performance on out-of-domain data, we will analyze pigeon pupils. For more
+discussions and work on so-called out-of-domain data, see
+[Mathis, Biasi 2020](https://paperswithcode.com/dataset/horse-10).
## Pigeon Pupil
-The eye pupil admits and regulates the amount of light entering the retina in order to enable image perception. Beside this curicial role, the pupil also reflects the state of the brain. The systemic behavior of the pupil has not been vastly studied in birds, although researchers from Max Planck Institute for Ornithology in Seewiesen have shed light on pupil behaviors in pigeons.
+The eye pupil admits and regulates the amount of light entering the retina in order to enable image perception. Beside
+this curicial role, the pupil also reflects the state of the brain. The systemic behavior of the pupil has not been
+vastly studied in birds, although researchers from
+Max Planck Institute for Ornithology in Seewiesen
+have shed light on pupil behaviors in pigeons.
-The pupils of male pigeons get smaller during courtship behavior. This is in contrast to mammals, for which the pupil size dilates in response to an increase in arousal. In addition, the pupil size of pigeons dilates during non-REM sleep, while they rapidly constrict during REM sleep. Examining these differences and the reason behind them, might be helpful to understand the pupillary behavior in general.
+The pupils of male pigeons get smaller during courtship behavior. This is in contrast to mammals, for which the pupil
+size dilates in response to an increase in arousal. In addition, the pupil size of pigeons dilates during non-REM sleep,
+while they rapidly constrict during REM sleep. Examining these differences and the reason behind them, might be helpful
+to understand the pupillary behavior in general.
-In light of these findings, we wanted to show whether the **mouse_pupil_vclose** model give us an accurate tracking performance for the pigeon pupil as well.
+In light of these findings, we wanted to show whether the **mouse_pupil_vclose** model give us an accurate tracking
+performance for the pigeon pupil as well.
### Jupyter & Google Colab Notebook
-DeepLabCut provides a Google Colab Notebook to analyze your video with a pretrained networks from the ModelZoo. No need for local installation of DeepLabCut!
+DeepLabCut provides a Google Colab Notebook to analyze your video with a pretrained networks from the ModelZoo. No need
+for local installation of DeepLabCut!
-Since we are interested in the accuracy of the **mouse_pupil_vclose** on pigeon pupil data, we will use a video which consists of 7 recordings of pigeon pupils.
+Since we are interested in the accuracy of the **mouse_pupil_vclose** on pigeon pupil data, we will use a video which
+consists of 7 recordings of pigeon pupils.
-Check ModelZoo Colab page and a video tutorial on how to use the ModelZoo on Google Colab.
+Check the
+[ModelZoo Colab page](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_DLC_ModelZoo.ipynb)
+and a video tutorial on how to use the ModelZoo on Google Colab.
@@ -63,35 +85,39 @@ files.download("/content/file.zip")
### Analyze Videos at Your Local Machine
-DeepLabCut host models from the
DeepLabCut ModelZoo Project .
+DeepLabCut host models from the [DeepLabCut ModelZoo Project](http://modelzoo.deeplabcut.org/).
The `create_pretrained_project` function will create a new project directory with the necessary sub-directories and a basic configuration file.
It will also initialize your project with a pre-trained model from the DeepLabCut ModelZoo.
The rest of the code should be run within your DeepLabCut environment.
-Check
here for the instructions for the DeepLabCut installation.
+Check [here](how-to-install) for the instructions for the DeepLabCut installation.
+To initialize a new project directory with a pre-trained model from the DeepLabCut ModelZoo, run the code below.
+
+::::{warning}
+This method is currently implemented for Tensorflow only, Pytorch compatibility is coming soon.
+::::
```python
import deeplabcut
-```
-To initialize a new project directory with a pre-trained model from the DeepLabCut ModelZoo, run the code below.
-```python
deeplabcut.create_pretrained_project(
"projectname",
"experimenter",
[r"path_for_the_videos"],
- model= "mouse_pupil_vclose",
- working_directory= r"project_directory",
- copy_videos= True,
- videotype= ".mp4 or .avi?",
- analyzevideo= True,
- filtered= True,
- createlabeledvideo= True,
- trainFraction= None
+ model="mouse_pupil_vclose",
+ working_directory=r"project_directory",
+ copy_videos=True,
+ videotype=".mp4 or .avi?",
+ analyzevideo=True,
+ filtered=True,
+ createlabeledvideo=True,
+ trainFraction=None,
+ engine=deeplabcut.Engine.TF,
)
```
+
::::{important}
Your videos should be cropped around the eye for better model accuracy! 👁🐭
::::
@@ -100,13 +126,12 @@ Excitingly, 6 out of the 7 pigeon pupils were tracked nicely:
-
-When we further evaluate the model accuracy by checking the likelihood of tracked points, we see that the tracking is low confidience when the pigeons close their eyelid (which is of course expected, and can be leveraged to measure blinking 👁).
-
+When we further evaluate the model accuracy by checking the likelihood of tracked points, we see that the tracking is
+low confidience when the pigeons close their eyelid (which is of course expected, and can be leveraged to measure
+blinking 👁).
-
But you also might encounter larger problems than small tracking glitches:
@@ -117,12 +142,26 @@ The more problems you encounter, the higher the number of frames you might want
You should also add the path of the video(s) into the `config.yaml` file, or run the following command to add the videos to your project:
```python
-deeplabcut.add_new_videos('/pathofproject/config.yaml', ['/pathofvideos/pigeon.mp4'], copy_videos=False, coords=None, extract_frames=False)
+deeplabcut.add_new_videos(
+ "/pathofproject/config.yaml",
+ ["/pathofvideos/pigeon.mp4"],
+ copy_videos=False,
+ coords=None,
+ extract_frames=False
+)
```
The `deeplabcut.extract_outlier_frames` function will check for outliers and ask your feedback on whether to extract these outliers frames.
```python
-deeplabcut.extract_outlier_frames('/pathofproject/config.yaml', ['/pathofvideos/pigeon.mp4'], automatic=True)
+deeplabcut.analyze_videos(
+ "/pathofproject/config.yaml",
+ ["/pathofvideos/pigeon.mp4"]
+)
+deeplabcut.extract_outlier_frames(
+ "/pathofproject/config.yaml",
+ ["/pathofvideos/pigeon.mp4"],
+ automatic=True
+)
```
The `deeplabcut.refine_labels` function starts the GUI which allows you to refine the outlier frames manually.
You should load the outlier frames directory and corresponding `.h5` file from the previous model.
@@ -130,9 +169,9 @@ It will ask you to define the `likelihood` threshold: labels under the threshold
After refining, you should combine these data with your previous model's data set and create a new training data set.
```python
-deeplabcut.refine_labels('/pathofproject/config.yaml')
-deeplabcut.merge_datasets('/pathofproject/config.yaml')
-deeplabcut.create_training_dataset('/pathofproject/config.yaml')
+deeplabcut.refine_labels("/pathofproject/config.yaml")
+deeplabcut.merge_datasets("/pathofproject/config.yaml")
+deeplabcut.create_training_dataset("/pathofproject/config.yaml")
```
Before starting the training of your model, there is one last step left: editing the `init_weights` parameter in your `pose_cfg.yaml` file.
Go to your project and check the latest snapshot (e.g., `snapshot-610000`) of your model in `dlc-models/train` directory.
@@ -142,7 +181,7 @@ Edit the value of the `init_weights` key in the `pose_cfg.yaml` file and start t
`init_weights: pathofyourproject\dlc-models\iteration-0\DLCFeb31-trainset95shuffle1\train\snapshot-610000`
```python
-deeplabcut.train_network('/pathofproject/config.yaml', shuffle=1, saveiters=25000)
+deeplabcut.train_network("/pathofproject/config.yaml", shuffle=1, saveiters=25000)
```
```{hint}
Check this video for model refining!
diff --git a/docs/recipes/installTips.md b/docs/recipes/installTips.md
index 3864af7815..ab4565a880 100644
--- a/docs/recipes/installTips.md
+++ b/docs/recipes/installTips.md
@@ -3,24 +3,24 @@
## How to use the latest updates directly from GitHub
-We often update the master deeplabcut code base on github, and then ~1 a month we push out a stable release on pypi. This is what most users turn to on a daily basis (i.e. pypi is where you get your `pip install deeplabcut` code from! But, sometimes we add things to the repo that are not yet integrated, or you might want to edit the code yourself. Here, we show you how to do this.
+We often update the master deeplabcut code base on GitHub, and then ~1 a month we push out a stable release on pypi. This is what most users turn to on a daily basis (i.e. pypi is where you get your `pip install deeplabcut` code from! But, sometimes we add things to the repo that are not yet integrated, or you might want to edit the code yourself. Here, we show you how to do this.
### Method 1:
-If you want to *use* the latest, you can use pip and add the specific tags, such as `tf` or `gui`, etc. by modifying and running:
+If you want to *use* the latest, you can use pip and add the specific tags, such as `gui`, etc. by modifying and running:
```
-pip install --upgrade 'git+https://github.com/deeplabcut/deeplabcut.git#egg=deeplabcut[tf]'
+pip install --upgrade 'git+https://github.com/deeplabcut/deeplabcut.git#egg=deeplabcut[gui]'
```
which will download and update deeplabcut, and any dependencies that don't match the new version. If you want to force upgrade all of the dependencies to the latest available versions, too, then use the additional `--upgrade-strategy eager`, i.e.:
```
-pip install --upgrade --upgrade-strategy eager 'git+https://github.com/deeplabcut/deeplabcut.git#egg=deeplabcut[tf,gui]'
+pip install --upgrade --upgrade-strategy eager 'git+https://github.com/deeplabcut/deeplabcut.git#egg=deeplabcut[gui]'
```
### Method 2:
-If you want to be able to *edi* the source code of DeepLabCut, i.e., maybe add a feature or fix a 🐛, then you need to "clone" the source code:
+If you want to be able to *edit* the source code of DeepLabCut, i.e., maybe add a feature or fix a 🐛, then you need to "clone" the source code:
**Step 1:**
@@ -54,6 +54,10 @@ If you make changes, you can also then utilize our test scripts. Run the desired
i.e., for example:
```
+# Testing with the PyTorch engine
+python testscript_pytorch_multi_animal.py
+
+# Testing with the TensorFlow engine
python testscript_multianimal.py
```
@@ -243,6 +247,7 @@ Share images, automate workflows, and more with a free Docker ID:
For more examples and ideas, visit:
https://docs.docker.com/get-started/
```
+
### Next, Anaconda!
Click here to get the ubuntu/linux package: https://www.anaconda.com/products/individual#linux
@@ -283,6 +288,11 @@ Follow prompts!
## Troubleshooting: Note, if you get a failed build due to wxPython (note, this does not happen on Ubuntu 18, 16, etc), i.e.:
+```{warning}
+DeepLabCut no longer uses `wxpython` for its GUI - if you're getting such an error,
+you're likely installing an old version of DeepLabCut.
+```
+
```python
ERROR: Command errored out with exit status 1: /home/mackenzie/anaconda3/envs/DLC-GPU/bin/python -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-0jsmkrr1/wxpython_aeff462b2060421a9cf65df55f63a126/setup.py'"'"'; __file__='"'"'/tmp/pip-install-0jsmkrr1/wxpython_aeff462b2060421a9cf65df55f63a126/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-pzy9q5u2/install-record.txt --single-version-externally-managed --compile --install-headers /home/mackenzie/anaconda3/envs/DLC-GPU/include/python3.7m/wxpython Check the logs for full command output.
@@ -313,11 +323,11 @@ Activate! `conda activate DEEPLABCUT` and then run: `conda install -c conda-forg
Then run `python -m deeplabcut` which launches the DLC GUI.
-## DeepLabCut MacOS M1 and M2 chip installation environment instructions:
-
-This only assumes you have anaconda installed.
+## DeepLabCut MacOS M-chip installation environment instructions:
-Use the `DEEPLABCUT_M1.yaml` conda file if you have an Macbok with an M1 or M2 chip, and follow these steps:
+This only assumes you have anaconda installed. Use the `DEEPLABCUT_M1.yaml` conda file
+if you have a newer MacBook (with an M1, M2, M3, M4 chip or more later), and follow
+these steps:
(1) git clone the deeplabcut cut repo:
@@ -330,17 +340,24 @@ git clone https://github.com/DeepLabCut/DeepLabCut.git
(3) Then, run:
```bash
-conda env create -f DEEPLABCUT_M1.yaml
+conda env create -f DEEPLABCUT.yaml
```
(4) Finally, activate your environment and to launch DLC with the GUI
```bash
-conda activate DEEPLABCUT_M1
+conda activate DEEPLABCUT
python -m deeplabcut
```
-The GUI will open. Of course, you can also run DeepLabCut in headless mode.
+The GUI will open. Of course, you can also run DeepLabCut in headless mode.
+
+If **you want to use the TensorFlow engine**, you'll need to install the `apple_mchips`
+extra with DeepLabCut. You can do so by running:
+
+```bash
+pip install deeplabcut[apple_mchips]
+```
## How to confirm that your GPU is being used by DeepLabCut
@@ -364,7 +381,8 @@ During training and analysis steps, DeepLabCut does not use the GPU processor he
(5) If you don't see activity there during training, then your GPU is likely not installed correctly for DeepLabCut. Return to the installation instructions, and be sure you installed CUDA 11+, and ran `conda install cudnn -c conda-forge` after installing DeepLabCut.
-## How to install DeepLabCut for Intel and AMD GPUs on Windows
+## How to install DeepLabCut for Intel and AMD GPUs on Windows for the TensorFlow engine
+
If you are on Windows 10/11 and have a DirectX 12 compatible GPU from any vendor (AMD, Intel, or Nvidia), you utilise GPU acceleration for inference, with an installation that is consistent between devices. This method uses [Tensorflow-directml](https://github.com/microsoft/tensorflow-directml) which uses DirectML instead of Cuda for ML training and inference.
To check the DirectX version of your installed GPU, type in dxdiag into windows search and select the run command. In system information, the bottom item of the list shows your DirectX version. In addition to this ensure your standard GPU drivers are up-to-date. Updating drivers by any official means (Nvidia Geforce experience, AMD radeon software, direct from the vendor website) is fine.
diff --git a/docs/recipes/nn.md b/docs/recipes/nn.md
index 75b44e315f..39ff354938 100644
--- a/docs/recipes/nn.md
+++ b/docs/recipes/nn.md
@@ -1,15 +1,16 @@
+(tf-training-tips-and-tricks)=
# Model training tips & tricks
-## Limiting a GPU's memory consumption
+## TensorFlow Engine: Limiting a GPU's memory consumption
-All GPU memory is allocated to training by default, preventing
+With TensorFlow, all GPU memory is allocated to training by default, preventing
other Tensorflow processes from being run on the same machine.
-A flexible solution to limiting memory usage is to call `deeplabcut.train(..., allow_growth=True)`,
-which dynamically grows the GPU memory region as it is needed.
-Another, stricter option is to explicitly cap GPU usage to only a fraction
-of the available memory. For example, allocating a maximum of 1/4 of the total
-memory could be done as follows:
+A flexible solution to limiting memory usage is to call
+`deeplabcut.train(..., allow_growth=True)`, which dynamically grows the GPU memory
+region as it is needed. Another, stricter option is to explicitly cap GPU usage to only
+a fraction of the available memory. For example, allocating a maximum of 1/4 of the
+total memory could be done as follows:
```python
import tensorflow as tf
@@ -18,6 +19,7 @@ gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
```
+(tf-custom-image-augmentation)=
## Using custom image augmentation
Image augmentation is the process of artificially expanding the training set
@@ -25,89 +27,104 @@ by applying various transformations to images (e.g., rotation or rescaling)
in order to make models more robust and more accurate (read our
[primer](https://www.sciencedirect.com/science/article/pii/S0896627320307170) for
more information). Although data augmentation is automatically accomplished
-by DeepLabCut, default values (see the augmentation variables in the
-[default pose_cfg.yaml](https://github.com/DeepLabCut/DeepLabCut/blob/master/deeplabcut/pose_cfg.yaml#L23-L74) file)
-can be readily overwritten prior to training.
+by DeepLabCut, default values can be readily overwritten prior to training. See the
+augmentation variables defined in the:
-Another option we discuss is a different data-efficient approach based on a method called active learning. See this [this blog post](https://github.com/DeepLabCut/DeepLabCut/blob/master/docs/recipes/nn.md#using-custom-image-augmentation) for further details.
+- PyTorch Engine: [docs for the `pytorch_config.yaml` file](dlc3-pytorch-config)
+- TensorFlow Engine: [default pose_cfg.yaml file](
+https://github.com/DeepLabCut/DeepLabCut/blob/main/deeplabcut/pose_cfg.yaml#L23-L74)
-When you `create_training_dataset` [you have several options](https://github.com/DeepLabCut/DeepLabCut/wiki/DOCSTRINGS#create_training_dataset) on what types of augmentation to use.
-```python
-deeplabcut.create_training_dataset(configpath, augmenter_type='imgaug')
-```
-
-When you do this (i.e. pass `augmenter_type`) what underlying files you are calling are these:
-https://github.com/DeepLabCut/DeepLabCut/tree/master/deeplabcut/pose_estimation_tensorflow/datasets
-You can look at what types of augmentation are available to you (or edit those files to add more). Moreover, you can add more options to the pose_cfg.yaml file. Here is a simple script you can modify and run to automatically edit the correct pose_cfg.yaml to add more augmentation to the `imgaug` loader (or open it and edit yourself).
-
-But, you can add more:
+For the single-animal TensorFlow models, [you have several options](
+https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#f-create-training-dataset-s-and-selection-of-your-neural-network)
+for image augmentation when calling `create_training_dataset`
-```python
-import deeplabcut
-
-train_pose_config, _ = deeplabcut.return_train_network_path(config_path)
-augs = {
- "gaussian_noise": True,
- "elastic_transform": True,
- "rotation": 180,
- "covering": True,
- "motion_blur": True,
-}
-deeplabcut.auxiliaryfunctions.edit_config(
- train_pose_config,
- augs,
-)
-```
-An in-depth tutorial on image augmentation and training hyperparameters can be found [here](https://deeplabcut.github.io/DeepLabCut/docs/recipes/pose_cfg_file_breakdown.html).
+An in-depth tutorial on image augmentation and training hyperparameters can be found [
+here](
+https://deeplabcut.github.io/DeepLabCut/docs/recipes/pose_cfg_file_breakdown.html).
## Evaluating intermediate (and all) snapshots
-The latest snapshot stored during training may not necessarily be the one that yields the highest performance. Therefore, you should analyze ALL snapshots, and select the best. Put 'all' in the snapshots section of the config.yaml to do this.
+The latest snapshot stored during training may not necessarily be the one that yields
+the highest performance. Therefore, you should analyze ALL snapshots, and select the
+best. Put 'all' in the snapshots section of the `config.yaml` to do this.
+(what-neural-network-should-i-use)=
## What neural network should I use? (Trade offs, speed performance, and considerations)
-### With the release of even more network options, you now have to decide what to use! This additionally flexibility is hopefully helpful, but we want to give you some guidance on where to start.
+You always select the network type when you create a training data set: i.e., standard
+dlc: `deeplabcut.create_training_dataset(config, net_type=resnet_50)` , or maDLC:
+`deeplabcut.create_multianimaltraining_dataset(config, net_type=dlcrnet_ms5)`. There is
+nothing else you should change.
+### PyTorch Engine
-**TL;DR - your best performance for most everything is ResNet-50; MobileNetV2-1 is much faster, needs less memory on your GPU to train and nearly as accurate.**
+The different architectures available are described in the [PyTorch model architectures
+](dlc3-architectures) page.
-You always select the network type when you create a training data set: i.e., standard dlc: `deeplabcut.create_training_dataset(config, net_type=resnet_50)` , or maDLC: `deeplabcut.create_multianimaltraining_dataset(config, net_type=dlcrnet_ms5)`. There is nothing else you should change.
+### TensorFlow Engine
+With the release of even more network options, you now have to decide what to use! This
+additionally flexibility is hopefully helpful, but we want to give you some guidance on
+where to start.
+
+**TL;DR - your best performance for most everything is ResNet-50; MobileNetV2-1 is much
+faster, needs less memory on your GPU to train and nearly as accurate.**
***
-## ResNets:
+#### ResNets:
-In Mathis et al. 2018 we benchmarked three networks: **ResNet-50, ResNet-101, and ResNet-101ws**. For ALL lab applications, ResNet-50 was enough. For all the demo videos on [www.deeplabcut.org](http://www.mousemotorlab.org/deeplabcut) the backbones are ResNet-50's. Thus, we recommend making this your go-to workhorse for data analysis. Here is a figure from the paper, see panel "B" (they are all within a few pixels of each other on the open-field dataset):
+In Mathis et al. 2018 we benchmarked three networks: **ResNet-50, ResNet-101, and
+ResNet-101ws**. For ALL lab applications, ResNet-50 was enough. For all the demo videos
+on [www.deeplabcut.org](http://www.mousemotorlab.org/deeplabcut) the backbones are
+ResNet-50's. Thus, we recommend making this your go-to workhorse for data analysis. Here
+is a figure from the paper, see panel "B" (they are all within a few pixels of each
+other on the open-field dataset):
-This is also one of the main result figures, generated with ResNet-50. BLUE is training - RED is testing - BLACK is our best human-level performance, and 10 pixels is the width of the mouse nose -so anything under that is good performance for us on this task!
+This is also one of the main result figures, generated with ResNet-50. BLUE is
+training - RED is testing - BLACK is our best human-level performance, and 10 pixels is
+the width - of the mouse nose -so anything under that is good performance for us on this
+task!
+
-Here are also some speed stats for analyzing videos with ResNet-50, see https://www.biorxiv.org/content/early/2018/10/30/457242 for more details:
+Here are also some speed stats for analyzing videos with ResNet-50, see
+https://www.biorxiv.org/content/early/2018/10/30/457242 for more details:
-**So, why use a ResNet-101 or even 152?** if you have a much more challenging problem, like multiple humans dancing, this is a good option. You should then also set `intermediate_supervision=True` for best performance in the pose_config.yaml of that shuffle folder ( before you train). Note, for ResNet-50 this does NOT help, and can hurt.
+**So, why use a ResNet-101 or even 152?** if you have a much more challenging problem,
+like multiple humans dancing, this is a good option. You should then also set
+`intermediate_supervision=True` for best performance in the `pose_config.yaml` of that
+shuffle folder (before you train). Note, for ResNet-50 this does NOT help, and can
+hurt.
-## When should I use a MobileNet?
+#### When should I use a MobileNet?
-MobileNets are fast to run, fast to train, more memory efficient, and faster for analysis (inference) - e.g. on CPUs they are 4 times faster, on GPUs up to 2x! So, if you don't have a GPU (or a GPU with little memory), and don't want to use Google COLAB, etc, then these are a great starting point.
+MobileNets are fast to run, fast to train, more memory efficient, and faster for
+analysis (inference) - e.g. on CPUs they are 4 times faster, on GPUs up to 2x! So, if
+you don't have a GPU (or a GPU with little memory), and don't want to use Google COLAB,
+etc, then these are a great starting point.
-They are smaller/shallower networks though, so you don't want to be pushing in very large images. So, be sure to use `deeplabcut.DownSampleVideo` on your data (which is frankly never a bad idea).
+They are smaller/shallower networks though, so you don't want to be pushing in very
+large images. So, be sure to use `deeplabcut.DownSampleVideo` on your data (which is
+frankly never a bad idea).
-Additionally, these are good options for running on "live" videos, i.e. if you want to give real-time feedback in an experiment, you can run a video around a smaller cropped area, and run this rather fast!
+Additionally, these are good options for running on "live" videos, i.e. if you want to
+give real-time feedback in an experiment, you can run a video around a smaller cropped
+area, and run this rather fast!
**So, how fast are they?**
-Here are comparisons of 4 MobileNetV2 variants to ResNet-50 and ResNet-101 (darkest red):
-read more here: https://arxiv.org/abs/1909.11229
+Here are comparisons of 4 MobileNetV2 variants to ResNet-50 and ResNet-101 (darkest
+red - read more here: https://arxiv.org/abs/1909.11229)
@@ -117,14 +134,26 @@ read more here: https://arxiv.org/abs/1909.11229
-## When should I use an EfficientNet?
+#### When should I use an EfficientNet?
-Built with inverse residual blocks like MobileNets, but more powerful than ResNets, due to optimal depth/width/resolution scaling, [EfficientNet](https://arxiv.org/abs/1905.11946) are an excellent choice if you want speed and performance. They do require more careful handling though! Especially for small datasets, you will need to tune the batch size and learning rates. So, we suggest these for more advanced users, or those willing to run experiments to find the best settings. Here is the speed comparison, and for performance see our latest work at: http://horse10.deeplabcut.org
+Built with inverse residual blocks like MobileNets, but more powerful than ResNets, due
+to optimal depth/width/resolution scaling, [EfficientNet](
+https://arxiv.org/abs/1905.11946) are an excellent choice if you want speed and
+performance. They do require more careful handling though! Especially for small
+datasets, you will need to tune the batch size and learning rates. So, we suggest these
+for more advanced users, or those willing to run experiments to find the best settings.
+Here is the speed comparison, and for performance see our latest work at:
+http://horse10.deeplabcut.org
-## How can I compare them?
+#### How can I compare them?
-Great question! So, the best way to do this is to use the **same** test/train split (that is generated in create_training_dataset) with different models. Here, as of 2.1+, we have a **new** function that lets you do this easily. Instead of using `create_training_dataset` you will run `create_training_model_comparison` (see the docstrings by `deeplabcut.create_training_model_comparison?` or run the Project Manager GUI - `deeplabcut.launch_dlc()`- for assistance.
+Great question! So, the best way to do this is to use the **same** test/train split (
+that is generated in create_training_dataset) with different models. Here, as of 2.1+,
+we have a **new** function that lets you do this easily. Instead of using
+`create_training_dataset` you will run `create_training_model_comparison` (see the
+docstrings by `deeplabcut.create_training_model_comparison?` or run the Project Manager
+GUI - `deeplabcut.launch_dlc()`- for assistance.
diff --git a/docs/recipes/pose_cfg_file_breakdown.md b/docs/recipes/pose_cfg_file_breakdown.md
index 2e980edc25..2f79ac28d3 100644
--- a/docs/recipes/pose_cfg_file_breakdown.md
+++ b/docs/recipes/pose_cfg_file_breakdown.md
@@ -1,5 +1,10 @@
# The `pose_cfg.yaml` Guideline Handbook
+::::{warning}
+The following is specific to Tensorflow-based models. To read the equivalent explanations for Pytorch-based models,
+click [here](dlc3-pytorch-config)
+::::
+
👋 Hello! Mabuhay! Hola! This recipe was written by the [2023 DLC AI Residents](https://www.deeplabcutairesidency.org/)!
When you train, evaluate, and run inference with a neural network there are hyperparatmeters you must consider. While DLC attempts to set the "globally good for everyone" parameters, you might want to change them. Therefore, in this recipe we will review the pose config parameters related to neural network models' and the related data augmentation!
diff --git a/docs/recipes/publishing_notebooks_into_the_DLC_main_cookbook.md b/docs/recipes/publishing_notebooks_into_the_DLC_main_cookbook.md
index 31710c3a2d..4a7e47b856 100644
--- a/docs/recipes/publishing_notebooks_into_the_DLC_main_cookbook.md
+++ b/docs/recipes/publishing_notebooks_into_the_DLC_main_cookbook.md
@@ -133,7 +133,7 @@ This command installs DeepLabCut along with the dependencies required to build t
14. **🎉PR Approval:🎉** Once your PR is approved, the maintainers will merge it into the main repository. Your notebook will then be a part of the DeepLabCut Jupyter book! Yay!
-Remember to always check the [DLC CONTRIBUTING guidelines](https://github.com/DeepLabCut/DeepLabCut/blob/main/CONTRIBUTING.md).
+Remember to always check the [DLC contributing guidelines](https://github.com/DeepLabCut/DeepLabCut/blob/main/CONTRIBUTING.md).
## Wrap-Up 🎉
diff --git a/docs/roadmap.md b/docs/roadmap.md
index 905a95ae6f..b303754e37 100644
--- a/docs/roadmap.md
+++ b/docs/roadmap.md
@@ -2,7 +2,7 @@
## A development roadmap for DeepLabCut
-:loudspeaker: :hourglass_flowing_sand: :construction:
+📢 ⏳ 🚧
**General Enhancements:**
- [ ] DeepLabCut PyTorch & Model Zoo --> DLC 3.0 🔥
@@ -34,8 +34,8 @@
- [X] DeepLabCut-live! published in eLife
**DeepLabCut Model Zoo: a collection of pretrained models for plug-in-play DLC and community crowd-sourcing.**
-- [X] BETA release with 2.1.8b0: http://www.mousemotorlab.org/dlc-modelzoo
-- [X] full release with 2.1.8.1 http://www.mousemotorlab.org/dlc-modelzoo
+- [X] BETA release with 2.1.8b0: https://www.mackenziemathislab.org/deeplabcut
+- [X] full release with 2.1.8.1 https://www.mackenziemathislab.org/deeplabcut
- [X] Manuscript forthcoming! --> see arXiv https://arxiv.org/abs/2203.07436
- [X] new models added; horse, cheetah
- [X] TopView_Mouse model
diff --git a/docs/standardDeepLabCut_UserGuide.md b/docs/standardDeepLabCut_UserGuide.md
index 39188b3d48..0b32ccec65 100644
--- a/docs/standardDeepLabCut_UserGuide.md
+++ b/docs/standardDeepLabCut_UserGuide.md
@@ -1,7 +1,8 @@
(single-animal-userguide)=
# DeepLabCut User Guide (for single animal projects)
-This document covers single/standard DeepLabCut use. If you have a complicated multi-animal scenario (i.e., they look the same), then please see our [maDLC user guide](multi-animal-userguide).
+This document covers single/standard DeepLabCut use. If you have a complicated multi-animal scenario (i.e., they look
+the same), then please see our [maDLC user guide](multi-animal-userguide).
To get started, you can use the GUI, or the terminal. See below.
@@ -11,7 +12,12 @@ To get started, you can use the GUI, or the terminal. See below.
**GUI:**
-To begin, navigate to Aanaconda Prompt Terminal and right-click to "open as admin "(Windows), or simply launch "Terminal" (unix/MacOS) on your computer. We assume you have DeepLabCut installed (if not, see Install docs!). Next, launch your conda env (i.e., for example `conda activate DEEPLABCUT`). Then, simply run ``python -m deeplabcut``. The below functions are available to you in an easy-to-use graphical user interface. While most functionality is available, advanced users might want the additional flexibility that command line interface offers. Read more below.
+To begin, navigate to Anaconda Prompt Terminal and right-click to "open as admin "(Windows), or simply launch
+"Terminal" (unix/MacOS) on your computer. We assume you have DeepLabCut installed (if not, see
+[install docs](how-to-install)!). Next, launch your conda env (i.e., for example `conda activate DEEPLABCUT`). Then,
+simply run `python -m deeplabcut`. The below functions are available to you in an easy-to-use graphical user interface.
+While most functionality is available, advanced users might want the additional flexibility that command line interface
+offers. Read more below.
```{Hint}
🚨 If you use Windows, please always open the terminal with administrator privileges! Right click, and "run as administrator".
```
@@ -20,11 +26,19 @@ To begin, navigate to Aanaconda Prompt Terminal and right-click to "open as admi
-As a reminder, the core functions are described in our [Nature Protocols](https://www.nature.com/articles/s41596-019-0176-0) paper (published at the time of 2.0.6). Additional functions and features are continually added to the package. Thus, we recommend you read over the protocol and then please look at the following documentation and the doctrings. Thanks for using DeepLabCut!
+As a reminder, the core functions are described in our
+[Nature Protocols paper](https://www.nature.com/articles/s41596-019-0176-0) (published at the time of 2.0.6).
+Additional functions and features are continually added to the package. Thus, we recommend you read over the protocol
+and then please look at the following documentation and the doctrings. Thanks for using DeepLabCut!
## DeepLabCut in the Terminal/Command line interface:
-To begin, navigate to Aanaconda Prompt Terminal and right-click to "open as admin "(Windows), or simply launch "Terminal" (unix/MacOS) on your computer. We assume you have DeepLabCut installed (if not, see Install docs!). Next, launch your conda env (i.e., for example `conda activate DEEPLABCUT`) and then type `ipython`. Then type `import deeplabcut`.
+To begin, navigate to Anaconda Prompt Terminal and right-click to "open as admin "(Windows), or simply launch
+"Terminal" (unix/MacOS) on your computer. We assume you have DeepLabCut installed (if not, see Install docs!). Next,
+launch your conda env (i.e., for example `conda activate DEEPLABCUT`) and then type `ipython`. Then type:
+```python
+import deeplabcut
+```
```{Hint}
🚨 If you use Windows, please always open the terminal with administrator privileges! Right click, and "run as administrator".
@@ -32,48 +46,89 @@ To begin, navigate to Aanaconda Prompt Terminal and right-click to "open as admi
### (A) Create a New Project
-The function **create\_new\_project** creates a new project directory, required subdirectories, and a basic project configuration file. Each project is identified by the name of the project (e.g. Reaching), name of the experimenter (e.g. YourName), as well as the date at creation.
+The function `create_new_project` creates a new project directory, required subdirectories, and a basic project
+configuration file. Each project is identified by the name of the project (e.g. Reaching), name of the experimenter
+(e.g. YourName), as well as the date at creation.
-Thus, this function requires the user to input the name of the project, the name of the experimenter, and the full path of the videos that are (initially) used to create the training dataset.
+Thus, this function requires the user to input the name of the project, the name of the experimenter, and the full
+path of the videos that are (initially) used to create the training dataset.
-Optional arguments specify the working directory, where the project directory will be created, and if the user wants to copy the videos (to the project directory). If the optional argument working\_directory is unspecified, the project directory is created in the current working directory, and if copy\_videos is unspecified symbolic links for the videos are created in the videos directory. Each symbolic link creates a reference to a video and thus eliminates the need to copy the entire video to the video directory (if the videos remain at the original location).
+Optional arguments specify the working directory, where the project directory will be created, and if the user wants
+to copy the videos (to the project directory). If the optional argument `working_directory` is unspecified, the
+project directory is created in the current working directory, and if `copy_videos` is unspecified symbolic links
+for the videos are created in the videos directory. Each symbolic link creates a reference to a video and thus
+eliminates the need to copy the entire video to the video directory (if the videos remain at the original location).
```python
-deeplabcut.create_new_project('Name of the project', 'Name of the experimenter', ['Full path of video 1', 'Full path of video2', 'Full path of video3'], working_directory='Full path of the working directory', copy_videos=True/False, multianimal=True/False)
+deeplabcut.create_new_project(
+ "Name of the project",
+ "Name of the experimenter",
+ ["Full path of video 1", "Full path of video2", "Full path of video3"],
+ working_directory="Full path of the working directory",
+ copy_videos=True/False,
+ multianimal=False
+)
```
**Important path formatting note**
-Windows users, you must input paths as: ``r'C:\Users\computername\Videos\reachingvideo1.avi' `` or
-
-`` 'C:\\Users\\computername\\Videos\\reachingvideo1.avi'``
-
- (TIP: you can also place ``config_path`` in front of ``deeplabcut.create_new_project`` to create a variable that holds the path to the config.yaml file, i.e. ``config_path=deeplabcut.create_new_project(...)``)
-
-
-This set of arguments will create a project directory with the name **Name of the project+name of the experimenter+date of creation of the project** in the **Working directory** and creates the symbolic links to videos in the **videos** directory. The project directory will have subdirectories: **dlc-models**, **labeled-data**, **training-datasets**, and **videos**. All the outputs generated during the course of a project will be stored in one of these subdirectories, thus allowing each project to be curated in separation from other projects. The purpose of the subdirectories is as follows:
-
-**dlc-models:** This directory contains the subdirectories *test* and *train*, each of which holds the meta information with regard to the parameters of the feature detectors in configuration files. The configuration files are YAML files, a common human-readable data serialization language. These files can be opened and edited with standard text editors. The subdirectory *train* will store checkpoints (called snapshots in TensorFlow) during training of the model. These snapshots allow the user to reload the trained model without re-training it, or to pick-up training from a particular saved checkpoint, in case the training was interrupted.
-
-**labeled-data:** This directory will store the frames used to create the training dataset. Frames from different videos are stored in separate subdirectories. Each frame has a filename related to the temporal index within the corresponding video, which allows the user to trace every frame back to its origin.
-
-**training-datasets:** This directory will contain the training dataset used to train the network and metadata, which contains information about how the training dataset was created.
-
-**videos:** Directory of video links or videos. When **copy\_videos** is set to ``False``, this directory contains symbolic links to the videos. If it is set to ``True`` then the videos will be copied to this directory. The default is ``False``. Additionally, if the user wants to add new videos to the project at any stage, the function **add\_new\_videos** can be used. This will update the list of videos in the project's configuration file.
+Windows users, you must input paths as: `r'C:\Users\computername\Videos\reachingvideo1.avi'` or
+` 'C:\\Users\\computername\\Videos\\reachingvideo1.avi'`
+
+TIP: you can also place `config_path` in front of `deeplabcut.create_new_project` to create a variable that holds
+the path to the config.yaml file, i.e. `config_path=deeplabcut.create_new_project(...)`
+
+This set of arguments will create a project directory with the name
+**
++** in the **Working directory** and
+creates the symbolic links to videos in the **videos** directory. The project directory will have subdirectories:
+**dlc-models**, **dlc-models-pytorch**, **labeled-data**, **training-datasets**, and **videos**. All the outputs
+generated during the course of a project will be stored in one of these subdirectories, thus allowing each project to be
+curated in separation from other projects. The purpose of the subdirectories is as follows:
+
+**dlc-models** and **dlc-models-pytorch** have a similar structure; the first contains files for the TensorFlow engine
+while the second contains files for the PyTorch engine. At the top level in these directories, there are directories
+referring to different iterations of label refinement (see below): **iteration-0**, **iteration-1**, etc.
+The iteration directories store shuffle directories, where each shuffle directory stores model data related to a
+particular experiment: trained and tested on a particular training and testing sets, and with a particular model
+architecture. Each shuffle directory contains the subdirectories *test* and *train*, each of which holds the meta
+information with regard to the parameters of the feature detectors in configuration files. The configuration files are
+YAML files, a common human-readable data serialization language. These files can be opened and edited with standard text
+editors. The subdirectory *train* will store checkpoints (called snapshots) during training of the model. These
+snapshots allow the user to reload the trained model without re-training it, or to pick-up training from a particular
+saved checkpoint, in case the training was interrupted.
+
+**labeled-data:** This directory will store the frames used to create the training dataset. Frames from different videos
+are stored in separate subdirectories. Each frame has a filename related to the temporal index within the corresponding
+video, which allows the user to trace every frame back to its origin.
+
+**training-datasets:** This directory will contain the training dataset used to train the network and metadata, which
+contains information about how the training dataset was created.
+
+**videos:** Directory of video links or videos. When **copy\_videos** is set to `False`, this directory contains
+symbolic links to the videos. If it is set to `True` then the videos will be copied to this directory. The default is
+`False`. Additionally, if the user wants to add new videos to the project at any stage, the function
+**add\_new\_videos** can be used. This will update the list of videos in the project's configuration file.
```python
-deeplabcut.add_new_videos('Full path of the project configuration file*', ['full path of video 4', 'full path of video 5'], copy_videos=True/False)
+deeplabcut.add_new_videos(
+ "Full path of the project configuration file*",
+ ["full path of video 4", "full path of video 5"],
+ copy_videos=True/False
+)
```
-*Please note, *Full path of the project configuration file* will be referenced as ``config_path`` throughout this protocol.
+*Please note, *Full path of the project configuration file* will be referenced as `config_path` throughout this
+protocol.
-The project directory also contains the main configuration file called *config.yaml*. The *config.yaml* file contains many important parameters of the project. A complete list of parameters including their description can be found in Box1.
+The project directory also contains the main configuration file called *config.yaml*. The *config.yaml* file contains
+many important parameters of the project. A complete list of parameters including their description can be found in
+Box1.
-The ``create_new_project`` step writes the following parameters to the configuration file: *Task*, *scorer*, *date*, *project\_path* as well as a list of videos *video\_sets*. The first three parameters should **not** be changed. The list of videos can be changed by adding new videos or manually removing videos.
+The `create_new_project` step writes the following parameters to the configuration file: *Task*, *scorer*, *date*,
+*project\_path* as well as a list of videos *video\_sets*. The first three parameters should **not** be changed. The
+list of videos can be changed by adding new videos or manually removing videos.
-
-
-
+
#### API Docs
````{admonition} Click the button to see API Docs
@@ -87,34 +142,61 @@ The ``create_new_project`` step writes the following parameters to the configura
-Next, open the **config.yaml** file, which was created during **create\_new\_project**. You can edit this file in any text editor. Familiarize yourself with the meaning of the parameters (Box 1). You can edit various parameters, in particular you **must add the list of *bodyparts* (or points of interest)** that you want to track. You can also set the *colormap* here that is used for all downstream steps (can also be edited at anytime), like labeling GUIs, videos, etc. Here any [matplotlib colormaps](https://matplotlib.org/tutorials/colors/colormaps.html) will do!
+Next, open the **config.yaml** file, which was created during **create\_new\_project**. You can edit this file in any
+text editor. Familiarize yourself with the meaning of the parameters (Box 1). You can edit various parameters, in
+particular you **must add the list of *bodyparts* (or points of interest)** that you want to track. You can also set the
+*colormap* here that is used for all downstream steps (can also be edited at anytime), like labeling GUIs, videos, etc.
+Here any [matplotlib colormaps](https://matplotlib.org/tutorials/colors/colormaps.html) will do!
Please DO NOT have spaces in the names of bodyparts.
**bodyparts:** are the bodyparts of each individual (in the above list).
- ### (C) Data Selection (extract frames)
-
-**CRITICAL:** A good training dataset should consist of a sufficient number of frames that capture the breadth of the behavior. This ideally implies to select the frames from different (behavioral) sessions, different lighting and different animals, if those vary substantially (to train an invariant, robust feature detector). Thus for creating a robust network that you can reuse in the laboratory, a good training dataset should reflect the diversity of the behavior with respect to postures, luminance conditions, background conditions, animal identities,etc. of the data that will be analyzed. For the simple lab behaviors comprising mouse reaching, open-field behavior and fly behavior, 100−200 frames gave good results [Mathis et al, 2018](https://www.nature.com/articles/s41593-018-0209-y). However, depending on the required accuracy, the nature of behavior, the video quality (e.g. motion blur, bad lighting) and the context, more or less frames might be necessary to create a good network. Ultimately, in order to scale up the analysis to large collections of videos with perhaps unexpected conditions, one can also refine the data set in an adaptive way (see refinement below).
-
-The function `extract_frames` extracts frames from all the videos in the project configuration file in order to create a training dataset. The extracted frames from all the videos are stored in a separate subdirectory named after the video file’s name under the ‘labeled-data’. This function also has various parameters that might be useful based on the user’s need.
+ ### (C) Select Frames to Label
+
+**CRITICAL:** A good training dataset should consist of a sufficient number of frames that capture the breadth of the
+behavior. This ideally implies to select the frames from different (behavioral) sessions, different lighting and
+different animals, if those vary substantially (to train an invariant, robust feature detector). Thus for creating a
+robust network that you can reuse in the laboratory, a good training dataset should reflect the diversity of the
+behavior with respect to postures, luminance conditions, background conditions, animal identities,etc. of the data that
+will be analyzed. For the simple lab behaviors comprising mouse reaching, open-field behavior and fly behavior, 100−200
+frames gave good results [Mathis et al, 2018](https://www.nature.com/articles/s41593-018-0209-y). However, depending on
+the required accuracy, the nature of behavior, the video quality (e.g. motion blur, bad lighting) and the context, more
+or less frames might be necessary to create a good network. Ultimately, in order to scale up the analysis to large
+collections of videos with perhaps unexpected conditions, one can also refine the data set in an adaptive way (see
+refinement below).
+
+The function `extract_frames` extracts frames from all the videos in the project configuration file in order to create
+a training dataset. The extracted frames from all the videos are stored in a separate subdirectory named after the video
+file’s name under the ‘labeled-data’. This function also has various parameters that might be useful based on the user’s
+need.
```python
-deeplabcut.extract_frames(config_path, mode='automatic/manual', algo='uniform/kmeans', userfeedback=False, crop=True/False)
+deeplabcut.extract_frames(
+ config_path,
+ mode="automatic/manual",
+ algo="uniform/kmeans",
+ crop=True/False,
+ userfeedback=False
+)
```
**CRITICAL POINT:** It is advisable to keep the frame size small, as large frames increase the training and
inference time. The cropping parameters for each video can be provided in the config.yaml file (and see below).
-When running the function extract_frames, if the parameter crop=True, then you will be asked to draw a box within the GUI (and this is written to the config.yaml file).
-
-`userfeedback` allows the user to check which videos they wish to extract frames from. In this way, if you added more videos to the config.yaml file it does not, by default, extract frames (again) from every video. If you wish to disable this question, set `userfeedback = True`.
-
-The provided function either selects frames from the videos that are randomly sampled from a uniform distribution (uniform), by clustering based on visual appearance (k-means), or by manual selection. Random
-selection of frames works best for behaviors where the postures vary across the whole video. However, some behaviors
-might be sparse, as in the case of reaching where the reach and pull are very fast and the mouse is not moving much
-between trials (thus, we have the default set to True, as this is best for most use-cases we encounter). In such a case, the function that allows selecting frames based on k-means derived quantization would
-be useful. If the user chooses to use k-means as a method to cluster the frames, then this function downsamples the
-video and clusters the frames using k-means, where each frame is treated as a vector. Frames from different clusters
-are then selected. This procedure makes sure that the frames look different. However, on large and long videos, this
-code is slow due to computational complexity.
+When running the function extract_frames, if the parameter crop=True, then you will be asked to draw a box within the
+GUI (and this is written to the config.yaml file).
+
+`userfeedback` allows the user to specify which videos they wish to extract frames from. When set to `"True"`, a dialog
+will be initiated, where the user is asked for each video if (additional/any) frames from this video should be
+extracted. Use this, e.g. if you have already labeled some folders and want to extract data for new videos.
+
+The provided function either selects frames from the videos that are randomly sampled from a uniform distribution
+(uniform), by clustering based on visual appearance (k-means), or by manual selection. Random uniform selection of
+frames works best for behaviors where the postures vary across the whole video. However, some behaviors might be sparse,
+as in the case of reaching where the reach and pull are very fast and the mouse is not moving much between trials. In
+such a case, the function that allows selecting frames based on k-means derived quantization would be useful. If the
+user chooses to use k-means as a method to cluster the frames, then this function downsamples the video and clusters the
+frames using k-means, where each frame is treated as a vector. Frames from different clusters are then selected. This
+procedure makes sure that the frames look different. However, on large and long videos, this code is slow due to
+computational complexity.
**CRITICAL POINT:** It is advisable to extract frames from a period of the video that contains interesting
behaviors, and not extract the frames across the whole video. This can be achieved by using the start and stop
@@ -122,13 +204,16 @@ parameters in the config.yaml file. Also, the user can change the number of fram
the numframes2extract in the config.yaml file.
However, picking frames is highly dependent on the data and the behavior being studied. Therefore, it is hard to
-provide all purpose code that extracts frames to create a good training dataset for every behavior and animal. If the user feels specific frames are lacking, they can extract hand selected frames of interest using the interactive GUI
+provide all purpose code that extracts frames to create a good training dataset for every behavior and animal. If the
+user feels specific frames are lacking, they can extract hand selected frames of interest using the interactive GUI
provided along with the toolbox. This can be launched by using:
```python
-deeplabcut.extract_frames(config_path, 'manual')
+deeplabcut.extract_frames(config_path, "manual")
```
The user can use the *Load Video* button to load one of the videos in the project configuration file, use the scroll
-bar to navigate across the video and *Grab a Frame* (or a range of frames, as of version 2.0.5) to extract the frame(s). The user can also look at the extracted frames and e.g. delete frames (from the directory) that are too similar before reloading the set and then manually annotating them.
+bar to navigate across the video and *Grab a Frame* (or a range of frames, as of version 2.0.5) to extract the frame(s).
+The user can also look at the extracted frames and e.g. delete frames (from the directory) that are too similar before
+reloading the set and then manually annotating them.
@@ -144,38 +229,42 @@ bar to navigate across the video and *Grab a Frame* (or a range of frames, as of
### (D) Label Frames
-The toolbox provides a function **label_frames** which helps the user to easily label all the extracted frames using
-an interactive graphical user interface (GUI). The user should have already named the body parts to label (points of
-interest) in the project’s configuration file by providing a list. The following command invokes the labeling toolbox.
+The toolbox provides a function **label_frames** which helps the user to easily label
+all the extracted frames using an interactive graphical user interface (GUI). The user
+should have already named the bodyparts to label (points of interest) in the
+project’s configuration file by providing a list. The following command invokes the
+napari-deeplabcut labelling GUI. Checkout the [napari-deeplabcut docs](napari-gui) for
+more information about the labelling workflow.
+
```python
deeplabcut.label_frames(config_path)
```
-The user needs to use the *Load Frames* button to select the directory which stores the extracted frames from one of
-the videos. Subsequently, the user can use one of the radio buttons (top right) to select a body part to label. RIGHT click to add the label. Left click to drag the label, if needed. If you label a part accidentally, you can use the middle button on your mouse to delete! If you cannot see a body part in the frame, skip over the label! Please see the ``HELP`` button for more user instructions. This auto-advances once you labeled the first body part. You can also advance to the next frame by clicking on the RIGHT arrow on your keyboard (and go to a previous frame with LEFT arrow).
-Each label will be plotted as a dot in a unique color.
-
-The user is free to move around the body part and once satisfied with its position, can select another radio button
-(in the top right) to switch to the respective body part (it otherwise auto-advances). The user can skip a body part if it is not visible. Once all the visible body parts are labeled, then the user can use ‘Next Frame’ to load the following frame. The user needs to save the labels after all the frames from one of the videos are labeled by clicking the save button at the bottom right. Saving the labels will create a labeled dataset for each video in a hierarchical data file format (HDF) in the
-subdirectory corresponding to the particular video in **labeled-data**. You can save at any intermediate step (even without closing the GUI, just hit save) and you return to labeling a dataset by reloading it!
-
-**CRITICAL POINT:** It is advisable to **consistently label similar spots** (e.g., on a wrist that is very large, try
-to label the same location). In general, invisible or occluded points should not be labeled by the user. They can
-simply be skipped by not applying the label anywhere on the frame.
-OPTIONAL: In the event of adding more labels to the existing labeled dataset, the user need to append the new
-labels to the bodyparts in the config.yaml file. Thereafter, the user can call the function **label_frames**. As of 2.0.5+: then a box will pop up and ask the user if they wish to display all parts, or only add in the new labels. Saving the labels after all the images are labelled will append the new labels to the existing labeled dataset.
+[🎥 DEMO](https://youtu.be/hsA9IB5r73E)
HOT KEYS IN THE Labeling GUI (also see "help" in GUI):
+
```
Ctrl + C: Copy labels from previous frame.
Keyboard arrows: advance frames.
Delete key: delete label.
```
+

+**CRITICAL POINT:** It is advisable to **consistently label similar spots** (e.g., on a wrist that is very large, try
+to label the same location). In general, invisible or occluded points should not be labeled by the user. They can
+simply be skipped by not applying the label anywhere on the frame.
+OPTIONAL: In the event of adding more labels to the existing labeled dataset, the user need to append the new
+labels to the bodyparts in the config.yaml file. Thereafter, the user can call the function **label_frames**. As of
+2.0.5+: then a box will pop up and ask the user if they wish to display all parts, or only add in the new labels.
+Saving the labels after all the images are labelled will append the new labels to the existing labeled dataset.
-### (E) Check Annotated Frames
+For more information, checkout the [napari-deeplabcut docs](napari-gui) for
+more information about the labelling workflow.
+
+### (E) Check Annotated Frames
OPTIONAL: Checking if the labels were created and stored correctly is beneficial for training, since labeling
is one of the most critical parts for creating the training dataset. The DeepLabCut toolbox provides a function
@@ -184,7 +273,10 @@ is one of the most critical parts for creating the training dataset. The DeepLab
deeplabcut.check_labels(config_path, visualizeindividuals=True/False)
```
-For each video directory in labeled-data this function creates a subdirectory with **labeled** as a suffix. Those directories contain the frames plotted with the annotated body parts. The user can double check if the body parts are labeled correctly. If they are not correct, the user can reload the frames (i.e. `deeplabcut.label_frames`), move them around, and click save again.
+For each video directory in labeled-data this function creates a subdirectory with **labeled** as a suffix. Those
+directories contain the frames plotted with the annotated body parts. The user can double check if the body parts are
+labeled correctly. If they are not correct, the user can reload the frames (i.e. `deeplabcut.label_frames`), move them
+around, and click save again.
#### API Docs
````{admonition} Click the button to see API Docs
@@ -194,99 +286,228 @@ For each video directory in labeled-data this function creates a subdirectory wi
```
````
-### (F) Create Training Dataset(s) and selection of your neural network
+(create-training-dataset)=
+### (F) Create Training Dataset
-**CRITICAL POINT:** Only run this step **where** you are going to train the network. If you label on your laptop but move your project folder to Google Colab or AWS, lab server, etc, then run the step below on that platform! If you labeled on a Windows machine but train on Linux, this is fine as of 2.0.4 onwards it will be done automatically (it saves file sets as both Linux and Windows for you).
+**CRITICAL POINT:** Only run this step **where** you are going to train the network. If you label on your laptop but
+move your project folder to Google Colab or AWS, lab server, etc, then run the step below on that platform! If you
+labeled on a Windows machine but train on Linux, this is fine as of 2.0.4 onwards it will be done automatically (it
+saves file sets as both Linux and Windows for you).
-- If you move your project folder, you must only change the `project_path` (which is done automatically) in the main config.yaml file - that's it - no need to change the video paths, etc! Your project is fully portable.
+- If you move your project folder, you must only change the `project_path` (which is done automatically) in the main
+config.yaml file - that's it - no need to change the video paths, etc! Your project is fully portable.
-- Be aware you select your neural network backbone at this stage. As of DLC3+ we support PyTorch (and TensorFlow, but this will be phased out).
+- Be aware you select your neural network backbone at this stage. As of DLC3+ we support PyTorch (and TensorFlow, but
+this will be phased out).
-**OVERVIEW:** This function combines the labeled datasets from all the videos and splits them to create train and test datasets. The training data will be used to train the network, while the test data set will be used for evaluating the network. The function **create_training_dataset** performs those steps.
+**OVERVIEW:** This function combines the labeled datasets from all the videos and splits them to create train and test
+datasets. The training data will be used to train the network, while the test data set will be used for evaluating the
+network.
```python
deeplabcut.create_training_dataset(config_path)
```
-- OPTIONAL: If the user wishes to benchmark the performance of the DeepLabCut, they can create multiple training datasets by specifying an integer value to the `num_shuffles`; see the docstring for more details.
-
-within **dlc-models** called ``test`` and ``train``, and these each have a configuration file called pose_cfg.yaml.
-Specifically, the user can edit the **pose_cfg.yaml** within the **train** subdirectory before starting the training. These configuration files contain meta information with regard to the parameters of the feature detectors. Key parameters are listed in Box 2.
+- OPTIONAL: If the user wishes to benchmark the performance of the DeepLabCut, they can create multiple training
+datasets by specifying an integer value to the `num_shuffles`; see the docstring for more details.
-**CRITICAL POINT:** At this step, for **create_training_dataset** you select the network you want to use, and any additional data augmentation (beyond our defaults). You can set ``net_type`` and ``augmenter_type`` when you call the function.
+The function creates a new shuffle(s) directory in the **dlc-models-pytorch** directory
+(**dlc-models** if using Tensorflow), in the current "iteration" directory.
+The `train` and `test` directories each have a configuration file
+(**pytorch_config.yaml** in **train** and **pose_cfg.yaml** in **test** for Pytorch models,
+**pose_cfg.yaml** in **train** and **test** for Tensorflow models).
+Specifically, the user can edit the **pytorch_config.yaml** (or **pose_cfg.yaml**) within the **train** subdirectory
+before starting the training. These configuration files contain meta information with regard to the parameters
+of the feature detectors. For more information about the **pytorch_config.yaml** file, see [here](dlc3-pytorch-config)
+(for TensorFlow-based models, see key parameters
+[here](https://github.com/DeepLabCut/DeepLabCut/blob/main/deeplabcut/pose_cfg.yaml)).
-- Networks: ImageNet pre-trained networks OR SuperAnimal pre-trained networks weights will be downloaded, as you select. You can decide to do transfer-learning (recommended) or "fine-tune" both the backbone and the decoder head. We suggest seeing our [dedicated documentation on models](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/architectures.html) for more information.
+**CRITICAL POINT:** At this step, for **create_training_dataset** you select the network you want to use, and any
+additional data augmentation (beyond our defaults). You can set `net_type`, `detector_type` (if using a detector)
+and `augmenter_type` when you call the function.
+- Networks: ImageNet pre-trained networks OR SuperAnimal pre-trained networks weights will be downloaded, as you
+select. You can decide to do transfer-learning (recommended) or "fine-tune" both the backbone and the decoder head. We
+suggest seeing our [dedicated documentation on models](dlc3-architectures) for more information (
+or the [this page on selecting models](what-neural-network-should-i-use) for the TensorFlow engine).
```{Hint}
-🚨 If they do not download (you will see this downloading in the terminal), then you may not have permission to do so - be sure to open your terminal "as an admin" (This is only something we have seen with some Windows users - see the **[docs for more help!](https://deeplabcut.github.io/DeepLabCut/docs/recipes/nn.html)**).
-```
-
-**DATA AUGMENTATION:** At this stage you can also decide what type of augmentation to use. The default loaders work well for most all tasks (as shown on www.deeplabcut.org), but there are many options, more data augmentation, intermediate supervision, etc. Please look at the [**pose_cfg.yaml**](https://github.com/DeepLabCut/DeepLabCut/blob/master/deeplabcut/pose_cfg.yaml) file for a full list of parameters **you might want to change before running this step.** There are several data loaders that can be used. For example, you can use the default loader (introduced and described in the Nature Protocols paper), [TensorPack](https://github.com/tensorpack/tensorpack) for data augmentation (currently this is easiest on Linux only), or [imgaug](https://imgaug.readthedocs.io/en/latest/). We recommend `imgaug`. You can set this by passing:``` deeplabcut.create_training_dataset(config_path, augmenter_type='imgaug') ```
-
-**For TensorFlow Models:** the differences of the loaders are as follows:
-- `imgaug`: a lot of augmentation possibilities, efficient code for target map creation & batch sizes >1 supported. You can set the parameters such as the `batch_size` in the `pose_cfg.yaml` file for the model you are training. This is the recommended DEFAULT!
-- `crop_scale`: our standard DLC 2.0 introduced in Nature Protocols variant (scaling, auto-crop augmentation)
-- `tensorpack`: a lot of augmentation possibilities, multi CPU support for fast processing, target maps are created less efficiently than in imgaug, does not allow batch size>1
-- `deterministic`: only useful for testing, freezes numpy seed; otherwise like default.
-
-**For PyTorch Models:**
-- #TODO: more information coming soon; in the meantime see the docstrings!
-
-
-**MODEL COMPARISON:** You can also test several models by creating the same test/train split for different networks. You can easily do this in the Project Manager GUI, which also lets you compare PyTorch and TensorFlow models.
-
-Please also consult the [following page on selecting models]( https://deeplabcut.github.io/DeepLabCut/docs/recipes/nn.html#what-neural-network-should-i-use-trade-offs-speed-performance-and-considerations)
+🚨 If they do not download (you will see this downloading in the terminal), then you may not have permission to do
+so - be sure to open your terminal "as an admin" (This is only something we have seen with some Windows users - see
+the **[docs for more help!](tf-training-tips-and-tricks)**).
+```
+
+**DATA AUGMENTATION:** At this stage you can also decide what type of augmentation to
+use. Once you've called `create_training_dataset`, you can edit the
+[**pytorch_config.yaml**](dlc3-pytorch-config) file that was created (or for the
+TensorFlow engine, the [**pose_cfg.yaml**](
+https://github.com/DeepLabCut/DeepLabCut/blob/main/deeplabcut/pose_cfg.yaml) file).
+
+- PyTorch Engine: [Albumentations](https://albumentations.ai/docs/) is used for data
+augmentation. Look at the [**pytorch_config.yaml**](dlc3-pytorch-config) for more
+information about image augmentation options.
+- TensorFlow Engine: The default augmentation works well for most tasks (as shown on
+www.deeplabcut.org), but there are many options, more data augmentation, intermediate
+supervision, etc. Here are the available loaders:
+ - `imgaug`: a lot of augmentation possibilities, efficient code for target map creation & batch sizes >1 supported.
+ You can set the parameters such as the `batch_size` in the `pose_cfg.yaml` file for the model you are training. This
+ is the recommended default!
+ - `crop_scale`: our standard DLC 2.0 introduced in Nature Protocols variant (scaling, auto-crop augmentation)
+ - `tensorpack`: a lot of augmentation possibilities, multi CPU support for fast processing, target maps are created
+ less efficiently than in imgaug, does not allow batch size>1
+ - `deterministic`: only useful for testing, freezes numpy seed; otherwise like default.
+
+**MODEL COMPARISON**: You can also test several models by creating the same train/test
+split for different networks.
+You can easily do this in the Project Manager GUI (by selecting the "Use an existing
+data split" option), which also lets you compare PyTorch and TensorFlow models.
+
+````{versionadded} 3.0.0
+You can now create new shuffles using the same train/test split as
+existing shuffles with `create_training_dataset_from_existing_split`. This allows you to
+compare model performance (between different architectures or when using different
+training hyper-parameters) as the shuffles were trained on the same data, and evaluated
+on the same test data!
+
+Example usage - creating 3 new shuffles (with indices 10, 11 and 12) for a ResNet 50
+pose estimation model, using the same data split as was used for shuffle 0:
- See Box 2 on how to specify **which network is loaded for training (including your own network, etc):**
-
-
-
-
+```python
+deeplabcut.create_training_dataset_from_existing_split(
+ config_path,
+ from_shuffle=0,
+ shuffles=[10, 11, 12],
+ net_type="resnet_50",
+)
+```
+````
-#### API Docs for deeplabcut.create_training_dataset
-````{admonition} Click the button to see API Docs
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_dataset
:class: dropdown
```{eval-rst}
.. include:: ./api/deeplabcut.create_training_dataset.rst
```
````
-#### API Docs for deeplabcut.create_training_model_comparison
-````{admonition} Click the button to see API Docs
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_model_comparison
:class: dropdown
```{eval-rst}
.. include:: ./api/deeplabcut.create_training_model_comparison.rst
```
````
+````{admonition} Click the button to see API Docs for deeplabcut.create_training_dataset_from_existing_split
+:class: dropdown
+```{eval-rst}
+.. include:: ./api/deeplabcut.create_training_dataset_from_existing_split.rst
+```
+````
+
### (G) Train The Network
The function ‘train_network’ helps the user in training the network. It is used as follows:
```python
deeplabcut.train_network(config_path)
```
-The set of arguments in the function starts training the network for the dataset created for one specific shuffle. Note that you can change the loader (imgaug/default/etc) as well as other training parameters in the **pose_cfg.yaml** file of the model that you want to train (before you start training).
+The set of arguments in the function starts training the network for the dataset created
+for one specific shuffle. Note that you can change training parameters in the
+[**pytorch_config.yaml**](dlc3-pytorch-config) file (or **pose_cfg.yaml** for TensorFlow
+models) of the model that you want to train (before you start training).
-Example parameters that one can call:
-```python
-deeplabcut.train_network(config_path, shuffle=1, trainingsetindex=0, gputouse=None, max_snapshots_to_keep=5, autotune=False, displayiters=100, saveiters=15000, maxiters=30000, allow_growth=True)
-```
+At user specified iterations during training checkpoints are stored in the subdirectory
+*train* under the respective iteration & shuffle directory.
+
+````{admonition} Tips on training models with the PyTorch Engine
+:class: dropdown
-By default, the pretrained networks are not in the DeepLabCut toolbox (as they are around 100MB each), but they get downloaded before you train. However, if not previously downloaded, it will be downloaded and stored in a subdirectory *pre-trained* under the subdirectory *models* in *Pose_Estimation_Tensorflow* or *Pose_Estimation_PyTorch*.
-At user specified iterations during training checkpoints are stored in the subdirectory *train* under the respective iteration directory.
+Example parameters that one can call:
-If the user wishes to restart the training at a specific checkpoint they can specify the full path of the checkpoint to
-the variable ``init_weights`` in the **pose_cfg.yaml** file under the *train* subdirectory (see Box 2).
+```python
+deeplabcut.train_network(
+ config_path,
+ shuffle=1,
+ trainingsetindex=0,
+ device="cuda:0",
+ max_snapshots_to_keep=5,
+ displayiters=100,
+ save_epochs=5,
+ epochs=200,
+)
+```
+
+Pytorch models in DeepLabCut 3.0 are trained for a set number of epochs, instead of a
+maximum number of iterations (which is what was used for TensorFlow models). An epoch
+is a single pass through the training dataset, which means your model has seen each
+training image exactly once. So if you have 64 training images for your network, an
+epoch is 64 iterations with batch size 1 (or 32 iterations with batch size 2, 16 with
+batch size 4, etc.).
+
+By default, the pretrained networks are not in the DeepLabCut toolbox (as they can be
+more than 100MB), but they get downloaded automatically before you train.
+
+If the user wishes to restart the training at a specific checkpoint they can specify the
+full path of the checkpoint to the variable ``resume_training_from`` in the [
+**pytorch_config.yaml**](
+dlc3-pytorch-config) file (checkout the "Restarting Training at a Specific Checkpoint"
+section of the docs) under the *train* subdirectory.
+
+**CRITICAL POINT:** It is recommended to train the networks **until the loss plateaus**
+(depending on the dataset, model architecture and training hyper-parameters this happens
+after 100 to 250 epochs of training).
+
+The variables ``display_iters`` and ``save_epochs`` in the [**pytorch_config.yaml**](
+dlc3-pytorch-config) file allows the user to alter how often the loss is displayed
+and how often the weights are stored. We suggest saving every 5 to 25 epochs.
+````
-**CRITICAL POINT, For TensorFlow models:** it is recommended to train the ResNets or MobileNets for thousands of iterations until the loss plateaus (typically around **500,000**) if you use batch size 1. If you want to batch train, [we recommend using Adam, see more here](https://deeplabcut.github.io/DeepLabCut/docs/recipes/nn.html#using-custom-image-augmentation).
+````{admonition} Tips on training models with the TensorFlow Engine
+:class: dropdown
-**CRITICAL POINT, For PyTorch models:** PyTorch uses "epochs" not iterations. Please see our dedicated documentation that [explains how best to set the number of epochs here](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/user_guide.html). When in doubt, stick to the default! A bonus, training time is much less!
+Example parameters that one can call:
-**maDeepLabCut CRITICAL POINT:** For multi-animal projects we are using not only different and new output layers, but also new data augmentation, optimization, learning rates, and batch training defaults. Thus, please use a lower ``save_iters`` and ``maxiters``. I.e., we suggest saving every 10K-15K iterations, and only training until 50K-100K iterations. We recommend you look closely at the loss to not overfit on your data. The bonus, training time is much less!
+```python
+deeplabcut.train_network(
+ config_path,
+ shuffle=1,
+ trainingsetindex=0,
+ gputouse=None,
+ max_snapshots_to_keep=5,
+ autotune=False,
+ displayiters=100,
+ saveiters=25000,
+ maxiters=300000,
+ allow_growth=True,
+)
+```
+
+By default, the pretrained networks are not in the DeepLabCut toolbox (as they are
+around 100MB each), but they get downloaded before you train. However, if not previously
+downloaded from the TensorFlow model weights, it will be downloaded and stored in a
+subdirectory *pre-trained* under the subdirectory *models* in
+*Pose_Estimation_Tensorflow*. At user specified iterations during training checkpoints
+are stored in the subdirectory *train* under the respective iteration directory.
+
+If the user wishes to restart the training at a specific checkpoint they can specify the
+full path of the checkpoint to the variable ``init_weights`` in the **pose_cfg.yaml**
+file under the *train* subdirectory (see Box 2).
+
+**CRITICAL POINT:** It is recommended to train the networks for thousands of iterations
+until the loss plateaus (typically around **500,000**) if you use batch size 1. If you
+want to batch train, we recommend using Adam,
+[see more here](tf-custom-image-augmentation).
+
+The variables ``display_iters`` and ``save_iters`` in the **pose_cfg.yaml** file allows
+the user to alter how often the loss is displayed and how often the weights are stored.
+
+**maDeepLabCut CRITICAL POINT:** For multi-animal projects we are using not only
+different and new output layers, but also new data augmentation, optimization, learning
+rates, and batch training defaults. Thus, please use a lower ``save_iters`` and
+``maxiters``. I.e. we suggest saving every 10K-15K iterations, and only training until
+50K-100K iterations. We recommend you look closely at the loss to not overfit on your
+data. The bonus, training time is much less!!!
+````
-#### API Docs
-````{admonition} Click the button to see API Docs
+````{admonition} Click the button to see API Docs for train_network
:class: dropdown
```{eval-rst}
.. include:: ./api/deeplabcut.train_network.rst
@@ -296,36 +517,40 @@ the variable ``init_weights`` in the **pose_cfg.yaml** file under the *train* su
### (H) Evaluate the Trained Network
It is important to evaluate the performance of the trained network. This performance is measured by computing
-the mean average Euclidean error (MAE; which is proportional to the average root mean square error) between the
-manual labels and the ones predicted by DeepLabCut. The MAE is saved as a comma separated file and displayed
-for all pairs and only likely pairs (>p-cutoff). This helps to exclude, for example, occluded body parts. One of the
-strengths of DeepLabCut is that due to the probabilistic output of the scoremap, it can, if sufficiently trained, also
-reliably report if a body part is visible in a given frame. (see discussions of finger tips in reaching and the Drosophila
-legs during 3D behavior in [Mathis et al, 2018]). The evaluation results are computed by typing:
+the average root mean square error (RMSE) between the manual labels and the ones predicted by DeepLabCut.
+The RMSE is saved as a comma separated file and displayed for all pairs and only likely pairs (>p-cutoff).
+This helps to exclude, for example, occluded body parts. One of the strengths of DeepLabCut is that due to the
+probabilistic output of the scoremap, it can, if sufficiently trained, also reliably report if a body part is visible
+in a given frame. (see discussions of finger tips in reaching and the Drosophila legs during 3D behavior in
+[Mathis et al, 2018]). The evaluation results are computed by typing:
```python
deeplabcut.evaluate_network(config_path, Shuffles=[1], plotting=True)
```
-Setting ``plotting`` to true plots all the testing and training frames with the manual and predicted labels. The user
+Setting `plotting` to true plots all the testing and training frames with the manual and predicted labels. The user
should visually check the labeled test (and training) images that are created in the ‘evaluation-results’ directory.
Ideally, DeepLabCut labeled unseen (test images) according to the user’s required accuracy, and the average train
-and test errors are comparable (good generalization). What (numerically) comprises an acceptable MAE depends on
-many factors (including the size of the tracked body parts, the labeling variability, etc.). Note that the test error can
-also be larger than the training error due to human variability (in labeling, see Figure 2 in Mathis et al, Nature Neuroscience 2018).
+and test errors are comparable (good generalization). What (numerically) comprises an acceptable RMSE depends on
+many factors (including the size of the tracked body parts, the labeling variability, etc.). Note that the test error
+can also be larger than the training error due to human variability (in labeling, see Figure 2 in Mathis et al, Nature
+Neuroscience 2018).
**Optional parameters:**
-```
- Shuffles: list, optional -List of integers specifying the shuffle indices of the training dataset. The default is [1]
- plotting: bool, optional -Plots the predictions on the train and test images. The default is `False`; if provided it must be either `True` or `False`
+- `Shuffles: list, optional` - List of integers specifying the shuffle indices of the training dataset.
+The default is [1]
- show_errors: bool, optional -Display train and test errors. The default is `True`
+- `plotting: bool, optional` - Plots the predictions on the train and test images. The default is `False`;
+if provided it must be either `True` or `False`
- comparisonbodyparts: list of bodyparts, Default is all -The average error will be computed for those body parts only (Has to be a subset of the body parts).
+- `show_errors: bool, optional` - Display train and test errors. The default is `True`
- gputouse: int, optional -Natural number indicating the number of your GPU (see number in nvidia-smi). If you do not have a GPU, put None. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
-```
+- `comparisonbodyparts: list of bodyparts, Default is all` - The average error will be computed for those body parts
+only (Has to be a subset of the body parts).
+
+- `gputouse: int, optional` - Natural number indicating the number of your GPU (see number in nvidia-smi). If you do not
+have a GPU, put None. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries
The plots can be customized by editing the **config.yaml** file (i.e., the colormap, scale, marker size (dotsize), and
transparency of labels (alphavalue) can be modified). By default each body part is plotted in a different color
@@ -334,9 +559,10 @@ plotted as plus (‘+’), DeepLabCut’s predictions either as ‘.’ (for con
’x’ for (likelihood <= `pcutoff`).
The evaluation results for each shuffle of the training dataset are stored in a unique subdirectory in a newly created
-directory ‘evaluation-results’ in the project directory. The user can visually inspect if the distance between the labeled
-and the predicted body parts are acceptable. In the event of benchmarking with different shuffles of same training
-dataset, the user can provide multiple shuffle indices to evaluate the corresponding network.
+directory ‘evaluation-results-pytorch’ (‘evaluation-results’ for tensorflow models) in the project directory.
+The user can visually inspect if the distance between the labeled and the predicted body parts are acceptable.
+In the event of benchmarking with different shuffles of same training dataset, the user can provide multiple shuffle
+indices to evaluate the corresponding network.
Note that with multi-animal projects additional distance statistics aggregated over animals or bodyparts are also stored
in that directory. This aims at providing a finer quantitative evaluation of multi-animal prediction performance
before animal tracking. If the generalization is not sufficient, the user might want to:
@@ -363,24 +589,40 @@ you can drop "Indices" to run this on all training/testing images (this is slow!
```
````
-### (I) Novel Video Analysis:
+### (I) Analyze new Videos
-The trained network can be used to analyze new videos. The user needs to first choose a checkpoint with the best
-evaluation results for analyzing the videos. In this case, the user can enter the corresponding index of the checkpoint
-to the variable snapshotindex in the config.yaml file. By default, the most recent checkpoint (i.e. last) is used for
-analyzing the video. Novel/new videos **DO NOT have to be in the config file!** You can analyze new videos anytime by simply using the following line of code:
+The trained network can be used to analyze new videos. Novel/new videos **DO NOT have to be in the config file!**.
+You can analyze new videos anytime by simply using the following line of code:
```python
-deeplabcut.analyze_videos(config_path, ['fullpath/analysis/project/videos/reachingvideo1.avi'], save_as_csv=True)
+deeplabcut.analyze_videos(
+ config_path, ["fullpath/analysis/project/videos/reachingvideo1.avi"],
+ save_as_csv=True
+)
```
There are several other optional inputs, such as:
```python
-deeplabcut.analyze_videos(config_path, videos, videotype='avi', shuffle=1, trainingsetindex=0, gputouse=None, save_as_csv=False, destfolder=None, dynamic=(True, .5, 10))
-```
-The labels are stored in a [MultiIndex Pandas Array](http://pandas.pydata.org), which contains the name of the network, body part name, (x, y) label position in pixels, and the likelihood for each frame per body part. These
-arrays are stored in an efficient Hierarchical Data Format (HDF) in the same directory, where the video is stored.
-However, if the flag ``save_as_csv`` is set to ``True``, the data can also be exported in comma-separated values format
-(.csv), which in turn can be imported in many programs, such as MATLAB, R, Prism, etc.; This flag is set to ``False``
-by default. You can also set a destination folder (``destfolder``) for the output files by passing a path of the folder you wish to write to.
+deeplabcut.analyze_videos(
+ config_path,
+ videos,
+ videotype="avi",
+ shuffle=1,
+ trainingsetindex=0,
+ gputouse=None,
+ save_as_csv=False,
+ destfolder=None,
+ dynamic=(True, .5, 10)
+)
+```
+The user can choose a checkpoint for analyzing the videos. For this, the user can enter the corresponding index of the
+checkpoint to the variable snapshotindex in the config.yaml file. By default, the most recent checkpoint (i.e. last) is
+used for analyzing the video.
+The labels are stored in a MultiIndex [Pandas](http://pandas.pydata.org) Array, which contains the name of the network,
+body part name, (x, y) label position in pixels, and the likelihood for each frame per body part. These arrays are
+stored in an efficient Hierarchical Data Format (HDF) in the same directory, where the video is stored.
+However, if the flag `save_as_csv` is set to `True`, the data can also be exported in comma-separated values format
+(.csv), which in turn can be imported in many programs, such as MATLAB, R, Prism, etc.; This flag is set to `False`
+by default. You can also set a destination folder (`destfolder`) for the output files by passing a path of the folder
+you wish to write to.
#### API Docs
````{admonition} Click the button to see API Docs
@@ -394,26 +636,56 @@ by default. You can also set a destination folder (``destfolder``) for the outpu
#### Dynamic-cropping of videos:
-As of 2.1+ we have a dynamic cropping option. Namely, if you have large frames and the animal/object occupies a smaller fraction, you can crop around your animal/object to make processing speeds faster. For example, if you have a large open field experiment but only track the mouse, this will speed up your analysis (also helpful for real-time applications). To use this simply add ``dynamic=(True,.5,10)`` when you call ``analyze_videos``.
+As of 2.1+ we have a dynamic cropping option. Namely, if you have large frames and the animal/object occupies a smaller
+fraction, you can crop around your animal/object to make processing speeds faster. For example, if you have a large open
+field experiment but only track the mouse, this will speed up your analysis (also helpful for real-time applications).
+To use this simply add `dynamic=(True,.5,10)` when you call `analyze_videos`.
```python
dynamic: triple containing (state, detectiontreshold, margin)
- If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e., any body part > detectiontreshold), then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost; i.e., detectiontreshold),
+ then object boundaries are computed according to the smallest/largest x position and
+ smallest/largest y position of all body parts. This window is expanded by the margin
+ and from then on only the posture within this crop is analyzed (until the object is lost;
+ i.e., < detectiontreshold). The current position is utilized for updating the crop window
+ for the next frame (this is why the margin is important and should be set large enough
+ given the movement of the animal).
```
-### (J) Filter pose data (RECOMMENDED!):
+### (J) Filter Pose Data
You can also filter the predictions with a median filter (default) or with a [SARIMAX model](https://www.statsmodels.org/dev/generated/statsmodels.tsa.statespace.sarimax.SARIMAX.html), if you wish. This creates a new .h5 file with the ending *_filtered* that you can use in create_labeled_data and/or plot trajectories.
```python
-deeplabcut.filterpredictions(config_path, ['fullpath/analysis/project/videos/reachingvideo1.avi'])
+deeplabcut.filterpredictions(
+ config_path,
+ ["fullpath/analysis/project/videos/reachingvideo1.avi"]
+)
```
An example call:
- ```python
-deeplabcut.filterpredictions(config_path,['fullpath/analysis/project/videos'], videotype='.mp4',filtertype= 'arima',ARdegree=5,MAdegree=2)
- ```
+```python
+deeplabcut.filterpredictions(
+ config_path,
+ ["fullpath/analysis/project/videos"],
+ videotype=".mp4",
+ filtertype="arima",
+ ARdegree=5,
+ MAdegree=2
+)
+```
Here are parameters you can modify and pass:
```python
-deeplabcut.filterpredictions(config_path, ['fullpath/analysis/project/videos/reachingvideo1.avi'], shuffle=1, trainingsetindex=0, comparisonbodyparts='all', filtertype='arima', p_bound=0.01, ARdegree=3, MAdegree=1, alpha=0.01)
+deeplabcut.filterpredictions(
+ config_path,
+ ["fullpath/analysis/project/videos/reachingvideo1.avi"],
+ shuffle=1,
+ trainingsetindex=0,
+ filtertype="arima",
+ p_bound=0.01,
+ ARdegree=3,
+ MAdegree=1,
+ alpha=0.01
+)
```
Here is an example of how this can be applied to a video:
@@ -429,7 +701,7 @@ deeplabcut.filterpredictions(config_path, ['fullpath/analysis/project/videos/rea
```
````
-### (K) Plot Trajectories:
+### (K) Plot Trajectories
The plotting components of this toolbox utilizes matplotlib. Therefore, these plots can easily be customized by
the end user. We also provide a function to plot the trajectory of the extracted poses across the analyzed video, which
@@ -439,7 +711,11 @@ can be called by typing:
deeplabcut.plot_trajectories(config_path, [‘fullpath/analysis/project/videos/reachingvideo1.avi’])
```
-It creates a folder called ``plot-poses`` (in the directory of the video). The plots display the coordinates of body parts vs. time, likelihoods vs time, the x- vs. y- coordinate of the body parts, as well as histograms of consecutive coordinate differences. These plots help the user to quickly assess the tracking performance for a video. Ideally, the likelihood stays high and the histogram of consecutive coordinate differences has values close to zero (i.e. no jumps in body part detections across frames). Here are example plot outputs on a demo video (left):
+It creates a folder called `plot-poses` (in the directory of the video). The plots display the coordinates of body parts
+vs. time, likelihoods vs time, the x- vs. y- coordinate of the body parts, as well as histograms of consecutive
+coordinate differences. These plots help the user to quickly assess the tracking performance for a video. Ideally, the
+likelihood stays high and the histogram of consecutive coordinate differences has values close to zero (i.e. no jumps in
+body part detections across frames). Here are example plot outputs on a demo video (left):
@@ -454,51 +730,88 @@ It creates a folder called ``plot-poses`` (in the directory of the video). The p
```
````
-### (L) Create Labeled Videos:
+### (L) Create Labeled Videos
Additionally, the toolbox provides a function to create labeled videos based on the extracted poses by plotting the
-labels on top of the frame and creating a video. There are two modes to create videos: FAST and SLOW (but higher quality!). If you want to create high-quality videos, please add ``save_frames=True``. One can use the command as follows to create multiple labeled videos:
+labels on top of the frame and creating a video. There are two modes to create videos: FAST and SLOW (but higher
+quality!). One can use the command as follows to create multiple labeled videos:
```python
-deeplabcut.create_labeled_video(config_path, ['fullpath/analysis/project/videos/reachingvideo1.avi','fullpath/analysis/project/videos/reachingvideo2.avi'], save_frames = True/False)
-```
- Optionally, if you want to use the filtered data for a video or directory of filtered videos pass ``filtered=True``, i.e.:
+deeplabcut.create_labeled_video(
+ config_path,
+ ["fullpath/analysis/project/videos/reachingvideo1.avi",
+ "fullpath/analysis/project/videos/reachingvideo2.avi"],
+ save_frames = True/False
+)
+```
+ Optionally, if you want to use the filtered data for a video or directory of filtered videos pass `filtered=True`,
+ i.e.:
```python
-deeplabcut.create_labeled_video(config_path, ['fullpath/afolderofvideos'], videotype='.mp4', filtered=True)
-```
-You can also optionally add a skeleton to connect points and/or add a history of points for visualization. To set the "trailing points" you need to pass ``trailpoints``:
+deeplabcut.create_labeled_video(
+ config_path,
+ ["fullpath/afolderofvideos"],
+ videotype=".mp4",
+ filtered=True
+)
+```
+You can also optionally add a skeleton to connect points and/or add a history of points for visualization. To set the
+"trailing points" you need to pass `trailpoints`:
```python
-deeplabcut.create_labeled_video(config_path, ['fullpath/afolderofvideos'], videotype='.mp4', trailpoints=10)
-```
-To draw a skeleton, you need to first define the pairs of connected nodes (in the ``config.yaml`` file) and set the skeleton color (in the ``config.yaml`` file). There is also a GUI to help you do this, use by calling `deeplabcut.SkeletonBuilder(config+path)`!
-
-Here is how the ``config.yaml`` additions/edits should look (for example, on the Openfield demo data we provide):
+deeplabcut.create_labeled_video(
+ config_path,
+ ["fullpath/afolderofvideos"],
+ videotype=".mp4",
+ trailpoints=10
+)
+```
+To draw a skeleton, you need to first define the pairs of connected nodes (in the `config.yaml` file) and set the
+skeleton color (in the `config.yaml` file). There is also a GUI to help you do this, use by calling
+`deeplabcut.SkeletonBuilder(configpath)`!
+
+Here is how the `config.yaml` additions/edits should look (for example, on the Openfield demo data we provide):
```python
# Plotting configuration
-skeleton: [['snout', 'leftear'], ['snout', 'rightear'], ['leftear', 'tailbase'], ['leftear', 'rightear'], ['rightear','tailbase']]
+skeleton:
+ - ["snout", "leftear"]
+ - ["snout", "rightear"]
+ - ["leftear", "tailbase"]
+ - ["leftear", "rightear"]
+ - ["rightear", "tailbase"]
skeleton_color: white
pcutoff: 0.4
dotsize: 4
alphavalue: 0.5
colormap: jet
```
-Then pass ``draw_skeleton=True`` with the command:
+Then pass `draw_skeleton=True` with the command:
```python
-deeplabcut.create_labeled_video(config_path,['fullpath/afolderofvideos'], videotype='.mp4', draw_skeleton = True)
+deeplabcut.create_labeled_video(
+ config_path,
+ ["fullpath/afolderofvideos"],
+ videotype=".mp4",
+ draw_skeleton=True
+)
```
-**NEW** as of 2.2b8: You can create a video with only the "dots" plotted, i.e., in the [style of Johansson](https://link.springer.com/article/10.1007/BF00309043), by passing `keypoints_only=True`:
+**NEW** as of 2.2b8: You can create a video with only the "dots" plotted, i.e., in the
+[style of Johansson](https://link.springer.com/article/10.1007/BF00309043), by passing `keypoints_only=True`:
```python
-deeplabcut.create_labeled_video(config_path,['fullpath/afolderofvideos'], videotype='.mp4', keypoints_only=True)
+deeplabcut.create_labeled_video(
+ config_path,["fullpath/afolderofvideos"],
+ videotype=".mp4",
+ keypoints_only=True
+)
```
-**PRO TIP:** that the **best quality videos** are created when ``save_frames=True`` is passed. Therefore, when ``trailpoints`` and ``draw_skeleton`` are used, we **highly** recommend you also pass ``save_frames=True``!
+**PRO TIP:** that the **best quality videos** are created when `fastmode=False` is passed. Therefore, when
+`trailpoints` and `draw_skeleton` are used, we **highly** recommend you also pass `fastmode=False`!
-This function has various other parameters, in particular the user can set the ``colormap``, the ``dotsize``, and ``alphavalue`` of the labels in **config.yaml** file.
+This function has various other parameters, in particular the user can set the `colormap`, the `dotsize`, and
+`alphavalue` of the labels in **config.yaml** file.
#### API Docs
````{admonition} Click the button to see API Docs
@@ -510,10 +823,20 @@ This function has various other parameters, in particular the user can set the `
#### Extract "Skeleton" Features:
-NEW, as of 2.0.7+: You can save the "skeleton" that was applied in ``create_labeled_videos`` for more computations. Namely, it extracts length and orientation of each "bone" of the skeleton as defined in the **config.yaml** file. You can use the function by:
+NEW, as of 2.0.7+: You can save the "skeleton" that was applied in `create_labeled_videos` for more computations.
+Namely, it extracts length and orientation of each "bone" of the skeleton as defined in the **config.yaml** file. You
+can use the function by:
```python
-deeplabcut.analyzeskeleton(config, video, videotype='avi', shuffle=1, trainingsetindex=0, save_as_csv=False, destfolder=None)
+deeplabcut.analyzeskeleton(
+ config,
+ video,
+ videotype="avi",
+ shuffle=1,
+ trainingsetindex=0,
+ save_as_csv=False,
+ destfolder=None
+)
```
#### API Docs
@@ -524,6 +847,7 @@ deeplabcut.analyzeskeleton(config, video, videotype='avi', shuffle=1, trainingse
```
````
+(active-learning)=
### (M) Optional Active Learning -> Network Refinement: Extract Outlier Frames
While DeepLabCut typically generalizes well across datasets, one might want to optimize its performance in various,
@@ -539,37 +863,46 @@ where the decoder might make large errors.
All this can be done for a specific video by typing (see other optional inputs below):
```python
-deeplabcut.extract_outlier_frames(config_path, ['videofile_path'])
+deeplabcut.extract_outlier_frames(config_path, ["videofile_path"])
```
We provide various frame-selection methods for this purpose. In particular
the user can set:
```
-outlieralgorithm: 'fitting', 'jump', or 'uncertain'``
+outlieralgorithm: "fitting", "jump", or "uncertain"
```
-• select frames if the likelihood of a particular or all body parts lies below *pbound* (note this could also be due to
-occlusions rather than errors); (``outlieralgorithm='uncertain'``), but also set ``p_bound``.
+• `outlieralgorithm="uncertain"`: select frames if the likelihood of a particular or all body parts lies below `p_bound`
+(note this could also be due to occlusions rather than errors).
-• select frames where a particular body part or all body parts jumped more than *\uf* pixels from the last frame (``outlieralgorithm='jump'``).
+• `outlieralgorithm="jump"`: select frames where a particular body part or all body parts jumped more than `epsilon`
+pixels from the last frame.
-• select frames if the predicted body part location deviates from a state-space model fit to the time series
-of individual body parts. Specifically, this method fits an Auto Regressive Integrated Moving Average (ARIMA)
-model to the time series for each body part. Thereby each body part detection with a likelihood smaller than
-pbound is treated as missing data. Putative outlier frames are then identified as time points, where the average body part estimates are at least *\uf* pixel away from the fits. The parameters of this method are *\uf*, *pbound*, the ARIMA parameters as well as the list of body parts to average over (can also be ``all``).
+• `outlieralgorithm="fitting"`: select frames if the predicted body part location deviates from a state-space model fit
+to the time series of individual body parts. Specifically, this method fits an Auto Regressive Integrated Moving Average
+(ARIMA) model to the time series for each body part. Thereby each body part detection with a likelihood smaller than
+`p_bound` is treated as missing data. Putative outlier frames are then identified as time points, where the average
+body part estimates are at least `epsilon` pixels away from the fits. The parameters of this method are `epsilon`,
+`p_bound`, the ARIMA parameters as well as the list of body parts to average over (can also be `all`).
-• manually select outlier frames based on visual inspection from the user (``outlieralgorithm='manual'``).
+• `outlieralgorithm="manual"`: manually select outlier frames based on visual inspection from the user.
As an example:
```python
-deeplabcut.extract_outlier_frames(config_path, ['videofile_path'], outlieralgorithm='manual')
+deeplabcut.extract_outlier_frames(config_path, ["videofile_path"], outlieralgorithm="manual")
```
In general, depending on the parameters, these methods might return much more frames than the user wants to
-extract (``numframes2pick``). Thus, this list is then used to select outlier frames either by randomly sampling from this
-list (``extractionalgorithm='uniform'``), by performing ``extractionalgorithm='k-means'`` clustering on the corresponding frames.
+extract (`numframes2pick`). Thus, this list is then used to select outlier frames either by randomly sampling from
+this list (`extractionalgorithm="uniform"`), by performing `extractionalgorithm="kmeans"` clustering on the
+corresponding frames.
-In the automatic configuration, before the frame selection happens, the user is informed about the amount of frames satisfying the criteria and asked if the selection should proceed. This step allows the user to perhaps change the parameters of the frame-selection heuristics first (i.e. to make sure that not too many frames are qualified). The user can run the extract_outlier_frames iteratively, and (even) extract additional frames from the same video. Once enough outlier frames are extracted the refinement GUI can be used to adjust the labels based on user feedback (see below).
+In the automatic configuration, before the frame selection happens, the user is informed about the amount of frames
+satisfying the criteria and asked if the selection should proceed. This step allows the user to perhaps change the
+parameters of the frame-selection heuristics first (i.e. to make sure that not too many frames are qualified). The user
+can run the `extract_outlier_frames` method iteratively, and (even) extract additional frames from the same video.
+Once enough outlier frames are extracted the refinement GUI can be used to adjust the labels based on user feedback
+(see below).
#### API Docs
````{admonition} Click the button to see API Docs
@@ -603,7 +936,7 @@ deeplabcut.refine_labels(config_path)
```
This will launch a GUI where the user can refine the labels.
-Use the ‘Load Labels’ button to select one of the subdirectories, where the extracted frames are stored. Every label will be identified by a unique color. For better chances to identify the low-confidence labels, specify the threshold of the likelihood. This changes the body parts with likelihood below this threshold to appear as circles and the ones above as solid disks while retaining the same color scheme. Next, to adjust the position of the label, hover the mouse over the labels to identify the specific body part, left click and drag it to a different location. To delete a specific label, middle click on the label (once a label is deleted, it cannot be retrieved).
+Please refer to the [napari-deeplabcut docs](napari-gui) for more information about the labelling workflow.
After correcting the labels for all the frames in each of the subdirectories, the users should merge the data set to
create a new dataset. In this step the iteration parameter in the config.yaml file is automatically updated.
@@ -612,13 +945,16 @@ deeplabcut.merge_datasets(config_path)
```
Once the dataset is merged, the user can test if the merging process was successful by plotting all the labels (Step E).
Next, with this expanded training set the user can now create a novel training set and train the network as described
-in Steps F and G. The training dataset will be stored in the same place as before but under a different ``iteration #``
-subdirectory, where the ``#`` is the new value of ``iteration`` variable stored in the project’s configuration file (this is
-automatically done).
+in Steps F and G. The training dataset will be stored in the same place as before but under a different `iteration-#`
+subdirectory, where the ``#`` is the new value of `iteration` variable stored in the project’s configuration file
+(this is automatically done).
-Now you can run ``create_training_dataset``, then ``train_network``, etc. If your original labels were adjusted at all, start from fresh weights (the typically recommended path anyhow), otherwise consider using your already trained network weights (see Box 2).
+Now you can run `create_training_dataset`, then `train_network`, etc. If your original labels were adjusted at all,
+start from fresh weights (the typically recommended path anyhow), otherwise consider using your already trained network
+weights (see Box 2).
-If after training the network generalizes well to the data, proceed to analyze new videos. Otherwise, consider labeling more data.
+If after training the network generalizes well to the data, proceed to analyze new videos. Otherwise, consider labeling
+more data.
#### API Docs for deeplabcut.refine_labels
````{admonition} Click the button to see API Docs
@@ -639,12 +975,15 @@ If after training the network generalizes well to the data, proceed to analyze n
### Jupyter Notebooks for Demonstration of the DeepLabCut Workflow
We also provide two Jupyter notebooks for using DeepLabCut on both a pre-labeled dataset, and on the end user’s
-own dataset. Firstly, we prepared an interactive Jupyter notebook called run_yourowndata.ipynb that can serve as a
-template for the user to develop a project. Furthermore, we provide a notebook for an already started project with
-labeled data. The example project, named as Reaching-Mackenzie-2018-08-30 consists of a project configuration file
-with default parameters and 20 images, which are cropped around the region of interest as an example dataset. These
-images are extracted from a video, which was recorded in a study of skilled motor control in mice. Some example
-labels for these images are also provided. See more details [here](https://github.com/DeepLabCut/DeepLabCut/blob/master/examples).
+own dataset. Firstly, we prepared an interactive Jupyter notebook called
+[Demo_yourowndata.ipynb](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/JUPYTER/Demo_yourowndata.ipynb)
+that can serve as a template for the user to develop a project. Furthermore, we provide a notebook for an already
+started project with labeled data. The example project, named as
+[Reaching-Mackenzie-2018-08-30](https://github.com/DeepLabCut/DeepLabCut/tree/main/examples/Reaching-Mackenzie-2018-08-30)
+consists of a project configuration file with default parameters and 20 images, which are cropped around the region of
+interest as an example dataset. These images are extracted from a video, which was recorded in a study of skilled motor
+control in mice. Some example labels for these images are also provided. See more details
+[here](https://github.com/DeepLabCut/DeepLabCut/tree/main/examples).
## 3D Toolbox
diff --git a/examples/COLAB/COLAB_3miceDemo.ipynb b/examples/COLAB/COLAB_3miceDemo.ipynb
index d914c409fc..b3977a68bd 100644
--- a/examples/COLAB/COLAB_3miceDemo.ipynb
+++ b/examples/COLAB/COLAB_3miceDemo.ipynb
@@ -16,18 +16,22 @@
"id": "TGChzLdc-lUJ"
},
"source": [
- "# DeepLabCut 2.2 Toolbox Demo on 3 mice data\n",
+ "# DeepLabCut 2.2+ Toolbox Demo on 3 mice data\n",
"\n",
"\n",
"https://github.com/DeepLabCut/DeepLabCut\n",
"\n",
+ "Note: this Colab notebook was written to accompany the Nature Methods publication [_Multi-animal pose estimation, identification and tracking with DeepLabCut_](https://www.nature.com/articles/s41592-022-01443-0) with the TensorFlow engine. To learn about DeepLabCut 3.0+ and the PyTorch engine, you can check out our other notebooks (such as [`COLAB_YOURDATA_maDLC_TrainNetwork_VideoAnalysis.ipynb`](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_YOURDATA_maDLC_TrainNetwork_VideoAnalysis.ipynb)).\n",
+ "\n",
"### This notebook illustrates how to use COLAB for a multi-animal DeepLabCut (maDLC) Demo 3 mouse project:\n",
+ "\n",
"- load our mini-demo data that includes a pretrained model and unlabeled video.\n",
"- analyze a novel video.\n",
"- assemble animals and tracklets.\n",
"- create quality check plots and video.\n",
"\n",
"### To create a full maDLC pipeline please see our full docs: https://deeplabcut.github.io/DeepLabCut/README.html\n",
+ "\n",
"- Of interest is a full how-to for maDLC: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html\n",
"- a quick guide to maDLC: https://deeplabcut.github.io/DeepLabCut/docs/tutorial.html\n",
"- a demo COLAB for how to use maDLC on your own data: https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/COLAB/COLAB_maDLC_TrainNetwork_VideoAnalysis.ipynb\n",
diff --git a/examples/COLAB/COLAB_DEMO_SuperAnimal.ipynb b/examples/COLAB/COLAB_DEMO_SuperAnimal.ipynb
index 7f23ef5edb..3f9ce58556 100644
--- a/examples/COLAB/COLAB_DEMO_SuperAnimal.ipynb
+++ b/examples/COLAB/COLAB_DEMO_SuperAnimal.ipynb
@@ -1,261 +1,231 @@
{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3G1Nx3YLOVaZ"
- },
- "source": [
- "\n",
- " \n",
- " "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "23v-XAUNQIPY"
- },
- "source": [
- "# DeepLabCut Model Zoo: SuperAnimal models\n",
- "\n",
- "\n",
- "\n",
- "http://modelzoo.deeplabcut.org\n",
- "\n",
- "You can use this notebook to analyze videos with pretrained networks from our model zoo - NO local installation of DeepLabCut is needed!\n",
- "\n",
- "- **What you need:** a video of your favorite dog, cat, human, etc: check the list of currently available models here: http://modelzoo.deeplabcut.org\n",
- "\n",
- "- **What to do:** (1) in the top right corner, click \"CONNECT\". Then, just hit run (play icon) on each cell below and follow the instructions!\n",
- "\n",
- "## **Please consider giving back and labeling a little data to help make each network even better!**\n",
- "\n",
- "We have a WebApp, so no need to install anything, just a few clicks! We'd really appreciate your help! 🙏\n",
- " \n",
- "https://contrib.deeplabcut.org/\n",
- "\n",
- "\n",
- "- **Note, if you performance is less that you would like:** firstly check the labeled_video parameters (i.e. \"pcutoff\" that will set the video plotting) - see the end of this notebook.\n",
- "- You can also use the model in your own projects locally. Please be sure to cite the papers for the model, i.e., [Ye et al. 2023](https://arxiv.org/abs/2203.07436) 🎉\n",
- "\n",
- "\n",
- "\n",
- "## **Let's get going: install DeepLabCut into COLAB:**\n",
- "\n",
- "*Also, be sure you are connected to a GPU: go to menu, click Runtime > Change Runtime Type > select \"GPU\"*\n",
- "\n",
- "As the COLAB environments were updated to CUDA 12.X and Python 3.11, we need to install DeepLabCut and TensorFlow in a distinct way to get TensorFlow to connect to the GPU.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Cell 1 - Install TensorFlow, tensorpack and tf_slim versions compatible with DeepLabCut\n",
- "!pip install \"tensorflow==2.12.1\" \"tensorpack>=0.11\" \"tf_slim>=1.1.0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Downgrade PyTorch to a version using CUDA 11.8 and cudnn 8\n",
- "# This will also install the required CUDA libraries, for both PyTorch and TensorFlow\n",
- "!pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cu118"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Install the latest version of DeepLabCut\n",
- "!pip install \"git+https://github.com/DeepLabCut/DeepLabCut.git#egg=deeplabcut[modelzoo]\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# As described in https://www.tensorflow.org/install/pip#step-by-step_instructions, \n",
- "# create symbolic links to NVIDIA shared libraries:\n",
- "!ln -svf /usr/local/lib/python3.11/dist-packages/nvidia/*/lib/*.so* /usr/local/lib/python3.11/dist-packages/tensorflow"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TguLMTJpQx1_"
- },
- "source": [
- "## PLEASE, click \"restart runtime\" from the output above before proceeding!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4BejjXKFO2Zg"
- },
- "outputs": [],
- "source": [
- "import deeplabcut\n",
- "import os"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "GXf8N4v28Xqo"
- },
- "source": [
- "## Please select a video you want to run SuperAnimal-X on:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "xXNMNLe6xEBC"
- },
- "outputs": [],
- "source": [
- "from google.colab import files\n",
- "\n",
- "uploaded = files.upload()\n",
- "for filepath, content in uploaded.items():\n",
- " print(f'User uploaded file \"{filepath}\" with length {len(content)} bytes')\n",
- "video_path = os.path.abspath(filepath)\n",
- "video_name = os.path.splitext(video_path)[0]\n",
- "\n",
- "# If this cell fails (e.g., when using Safari in place of Google Chrome),\n",
- "# manually upload your video via the Files menu to the left\n",
- "# and define `video_path` yourself with right click > copy path on the video."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "A8sDYMa08f62"
- },
- "source": [
- "## Next select the model you want to use, Quadruped or TopViewMouse\n",
- "- See http://modelzoo.deeplabcut.org/ for more details on these models\n",
- "- The pcutoff is for visualization only, namely only keypoints with a value over what you set are shown. 0 is low confidience, 1 is perfect confidience of the model."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ge589yC4v9yX"
- },
- "outputs": [],
- "source": [
- "supermodel_name = \"superanimal_quadruped\" #@param [\"superanimal_topviewmouse\", \"superanimal_quadruped\"]\n",
- "pcutoff = 0.3 #@param {type:\"slider\", min:0, max:1, step:0.05}"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "zsB0pGtj9Luq"
- },
- "source": [
- "## Okay, let's go! 🐭🦓🐻"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "yqcnEVVSQDC0"
- },
- "outputs": [],
- "source": [
- "videotype = os.path.splitext(video_path)[1]\n",
- "scale_list = []\n",
- "\n",
- "deeplabcut.video_inference_superanimal(\n",
- " [video_path],\n",
- " supermodel_name,\n",
- " videotype=videotype,\n",
- " video_adapt=True,\n",
- " scale_list=scale_list,\n",
- " pcutoff=pcutoff,\n",
- ")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gPLZSBpD34Mj"
- },
- "source": [
- "## Let's view the video in Colab:\n",
- "- otherwise, you can download and look at the video from the left side of your screen! It will end with _labeled.mp4\n",
- "- If your data doesn't work as well as you'd like, consider fine-tuning our model on your data, changing the pcutoff, changing the scale-range\n",
- "(pick values smaller and larger than your video image input size). See our repo for more details."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "ejFJ1Pbg33i6"
- },
- "outputs": [],
- "source": [
- "from base64 import b64encode\n",
- "from IPython.display import HTML\n",
- "view_video = open(video_name+'DLC_snapshot-1000_labeled.mp4','rb').read()\n",
- "\n",
- "data_url = \"data:video/mp4;base64,\" + b64encode(view_video).decode()\n",
- "HTML(\"\"\"\n",
- "\n",
- " \n",
- " \n",
- "\"\"\" % data_url)"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "provenance": []
- },
- "gpuClass": "standard",
- "kernelspec": {
- "display_name": "dlc",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]"
- },
- "vscode": {
- "interpreter": {
- "hash": "ef00193d8f29a47f592f520086c931b5dd2a83e8a593fa0efe5afff3c413a788"
- }
- }
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3G1Nx3YLOVaZ"
+ },
+ "source": [
+ "\n",
+ " \n",
+ " "
+ ]
},
- "nbformat": 4,
- "nbformat_minor": 0
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "23v-XAUNQIPY"
+ },
+ "source": [
+ "# DeepLabCut Model Zoo: SuperAnimal models\n",
+ "\n",
+ "\n",
+ "\n",
+ "http://modelzoo.deeplabcut.org\n",
+ "\n",
+ "You can use this notebook to analyze videos with pretrained networks from our model zoo - NO local installation of DeepLabCut is needed!\n",
+ "\n",
+ "- **What you need:** a video of your favorite dog, cat, human, etc: check the list of currently available models here: http://modelzoo.deeplabcut.org\n",
+ "\n",
+ "- **What to do:** (1) in the top right corner, click \"CONNECT\". Then, just hit run (play icon) on each cell below and follow the instructions!\n",
+ "\n",
+ "## **Please consider giving back and labeling a little data to help make each network even better!**\n",
+ "\n",
+ "We have a WebApp, so no need to install anything, just a few clicks! We'd really appreciate your help! 🙏\n",
+ " \n",
+ "https://contrib.deeplabcut.org/\n",
+ "\n",
+ "\n",
+ "- **Note, if you performance is less that you would like:** firstly check the labeled_video parameters (i.e. \"pcutoff\" that will set the video plotting) - see the end of this notebook.\n",
+ "- You can also use the model in your own projects locally. Please be sure to cite the papers for the model, i.e., [Ye et al. 2023](https://arxiv.org/abs/2203.07436) 🎉\n",
+ "\n",
+ "\n",
+ "\n",
+ "## **Let's get going: install DeepLabCut into COLAB:**\n",
+ "\n",
+ "*Also, be sure you are connected to a GPU: go to menu, click Runtime > Change Runtime Type > select \"GPU\"*\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "03ylSyQ4O9Ee"
+ },
+ "outputs": [],
+ "source": "!pip install --pre deeplabcut"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "TguLMTJpQx1_"
+ },
+ "source": [
+ "## PLEASE, click \"restart runtime\" from the output above before proceeding!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4BejjXKFO2Zg"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "import deeplabcut"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GXf8N4v28Xqo"
+ },
+ "source": [
+ "## Please select a video you want to run SuperAnimal-X on:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xXNMNLe6xEBC"
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import files\n",
+ "\n",
+ "uploaded = files.upload()\n",
+ "for filepath, content in uploaded.items():\n",
+ " print(f'User uploaded file \"{filepath}\" with length {len(content)} bytes')\n",
+ "video_path = os.path.abspath(filepath)\n",
+ "video_name = os.path.splitext(video_path)[0]\n",
+ "\n",
+ "# If this cell fails (e.g., when using Safari in place of Google Chrome),\n",
+ "# manually upload your video via the Files menu to the left\n",
+ "# and define `video_path` yourself with right click > copy path on the video."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "A8sDYMa08f62"
+ },
+ "source": [
+ "## Next select the model you want to use, Quadruped or TopViewMouse\n",
+ "- See http://modelzoo.deeplabcut.org/ for more details on these models\n",
+ "- The pcutoff is for visualization only, namely only keypoints with a value over what you set are shown. 0 is low confidience, 1 is perfect confidience of the model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ge589yC4v9yX"
+ },
+ "outputs": [],
+ "source": [
+ "supermodel_name = \"superanimal_quadruped\" #@param [\"superanimal_topviewmouse\", \"superanimal_quadruped\"]\n",
+ "pcutoff = 0.15 #@param {type:\"slider\", min:0, max:1, step:0.05}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zsB0pGtj9Luq"
+ },
+ "source": [
+ "## Okay, let's go! 🐭🦓🐻"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "yqcnEVVSQDC0"
+ },
+ "outputs": [],
+ "source": [
+ "videotype = os.path.splitext(video_path)[1]\n",
+ "scale_list = []\n",
+ "\n",
+ "deeplabcut.video_inference_superanimal(\n",
+ " [video_path],\n",
+ " supermodel_name,\n",
+ " model_name=\"hrnet_w32\",\n",
+ " detector_name=\"fasterrcnn_resnet50_fpn_v2\",\n",
+ " videotype=videotype,\n",
+ " video_adapt=True,\n",
+ " scale_list=scale_list,\n",
+ " pcutoff=pcutoff,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gPLZSBpD34Mj"
+ },
+ "source": [
+ "## Let's view the video in Colab:\n",
+ "- otherwise, you can download and look at the video from the left side of your screen! It will end with _labeled.mp4\n",
+ "- If your data doesn't work as well as you'd like, consider fine-tuning our model on your data, changing the pcutoff, changing the scale-range\n",
+ "(pick values smaller and larger than your video image input size). See our repo for more details."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "ejFJ1Pbg33i6"
+ },
+ "outputs": [],
+ "source": [
+ "from base64 import b64encode\n",
+ "from IPython.display import HTML\n",
+ "\n",
+ "labeled_video_path = f\"{video_name}_superanimal_quadruped_hrnetw32_labeled_after_adapt.mp4\"\n",
+ "view_video = open(labeled_video_path, \"rb\").read()\n",
+ "\n",
+ "data_url = \"data:video/mp4;base64,\" + b64encode(view_video).decode()\n",
+ "HTML(\"\"\"\n",
+ "\n",
+ " \n",
+ " \n",
+ "\"\"\" % data_url)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": []
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "dlc",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 17:01:00) \n[Clang 13.0.1 ]"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "ef00193d8f29a47f592f520086c931b5dd2a83e8a593fa0efe5afff3c413a788"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
diff --git a/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb b/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb
index a81190bb53..b6abdcc0ae 100644
--- a/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb
+++ b/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb
@@ -16,8 +16,12 @@
"id": "TGChzLdc-lUJ"
},
"source": [
- "# DeepLabCut Toolbox - Colab Demo on Topview Mouse Data\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "# DeepLabCut 3.0 Toolbox - Colab Demo on TopView Mouse Data\n",
+ "\n",
+ "Some useful links:\n",
+ "\n",
+ "- [DeepLabCut's GitHub: github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Single Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html)\n",
"\n",
"\n",
"\n",
@@ -42,41 +46,9 @@
"id": "txoddlM8hLKm"
},
"source": [
- "## First, go to \"Runtime\" ->\"change runtime type\"->select \"Python3\", and then select \"GPU\"\n",
+ "## Installation\n",
"\n",
- "As the COLAB environments were updated to CUDA 12.X and Python 3.11, we need to install DeepLabCut and TensorFlow in a distinct way to get TensorFlow to connect to the GPU."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Install TensorFlow, tensorpack and tf_slim versions compatible with DeepLabCut\n",
- "!pip install \"tensorflow==2.12.1\" \"tensorpack>=0.11\" \"tf_slim>=1.1.0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Downgrade PyTorch to a version using CUDA 11.8 and cudnn 8\n",
- "# This will also install the required CUDA libraries, for both PyTorch and TensorFlow\n",
- "!pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cu118"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# As described in https://www.tensorflow.org/install/pip#step-by-step_instructions, \n",
- "# create symbolic links to NVIDIA shared libraries:\n",
- "!ln -svf /usr/local/lib/python3.11/dist-packages/nvidia/*/lib/*.so* /usr/local/lib/python3.11/dist-packages/tensorflow"
+ "### First, go to \"Runtime\" ->\"change runtime type\"->select \"Python3\", and then select \"GPU\""
]
},
{
@@ -115,7 +87,7 @@
"source": [
"# Install the latest DeepLabCut version (this will take a few minutes to install all the dependencies!)\n",
"%cd /content/cloned-DLC-repo/\n",
- "!pip install \".\""
+ "%pip install \".\""
]
},
{
@@ -124,7 +96,7 @@
"id": "XymV_Hnlp1OJ"
},
"source": [
- "## PLEASE, click \"restart runtime\" from the output above before proceeding! "
+ "### PLEASE, click \"restart runtime\" from the output above before proceeding!"
]
},
{
@@ -146,7 +118,7 @@
},
"outputs": [],
"source": [
- "#create a path variable that links to the config file:\n",
+ "# Create a path variable that links to the config file:\n",
"path_config_file = '/content/cloned-DLC-repo/examples/openfield-Pranav-2018-10-30/config.yaml'\n",
"\n",
"# Loading example data set:\n",
@@ -171,14 +143,26 @@
},
"outputs": [],
"source": [
- "#let's also change the display and save_iters just in case Colab takes away the GPU... \n",
- "#if that happens, you can reload from a saved point. Typically, you want to train to 200,000 + iterations.\n",
- "#more info and there are more things you can set: https://github.com/DeepLabCut/DeepLabCut/wiki/DOCSTRINGS#train_network\n",
+ "# Let's also change the display and save_epochs just in case Colab takes away\n",
+ "# the GPU... If that happens, you can reload from a saved point using the\n",
+ "# `snapshot_path` argument to `deeplabcut.train_network`:\n",
+ "# deeplabcut.train_network(..., snapshot_path=\"/content/.../snapshot-050.pt\")\n",
+ "\n",
+ "# Typically, you want to train to ~200 epochs. We set the batch size to 8 to\n",
+ "# utilize the GPU's capabilities.\n",
"\n",
- "deeplabcut.train_network(path_config_file, shuffle=1, displayiters=100,saveiters=500, maxiters=10000)\n",
+ "# More info and there are more things you can set:\n",
+ "# https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#g-train-the-network\n",
"\n",
- "#this will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end (default, 1.03M iterations). \n",
- "#Whichever you chose, you will see what looks like an error message, but it's not an error - don't worry...."
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=5,\n",
+ " epochs=200,\n",
+ " batch_size=8,\n",
+ ")\n",
+ "\n",
+ "# This will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end."
]
},
{
@@ -187,7 +171,9 @@
"id": "RiDwIVf5-3H_"
},
"source": [
- "We recommend you run this for ~1,000 iterations, just as a demo. This should take around 20 min. Note, that **when you hit \"STOP\" you will get a KeyInterrupt \"error\"! No worries! :)**"
+ "We recommend you run this for ~100 epochs, just as a demo. This should take around 15 minutes. Note, that **when you hit \"STOP\" you will get a `KeyboardInterrupt` \"error\"! No worries! :)**\n",
+ "\n",
+ "A new snapshot is saved every `save_epochs` epochs. So once you hit 80 epochs, your latest snapshot in `/content/cloned-DLC-repo/examples/openfield-Pranav-2018-10-30/dlc-models-pytorch/iteration-0/openfieldOct30-trainset95shuffle1/train` should be `snapshot-80.pt`. The best snapshot evaluated during training is saved, and is named `snapshot-best-XX.pt`, where `XX` is the number of epochs the model was trained with."
]
},
{
@@ -209,10 +195,10 @@
},
"outputs": [],
"source": [
- "%matplotlib notebook\n",
- "deeplabcut.evaluate_network(path_config_file,plotting=True)\n",
+ "deeplabcut.evaluate_network(path_config_file, plotting=True)\n",
"\n",
- "# Here you want to see a low pixel error! Of course, it can only be as good as the labeler, so be sure your labels are good!"
+ "# Here you want to see a low pixel error! Of course, it can only be as\n",
+ "# good as the labeler, so be sure your labels are good!"
]
},
{
@@ -222,7 +208,8 @@
},
"source": [
"**Check the images**:\n",
- "You can go look in the newly created \"evalutaion-results\" folder at the images. At around 3500 iterations, the error is ~3 pixels (but this can vary on how your demo data was split for training)"
+ "\n",
+ "You can go look in the newly created `\"evalutaion-results-pytorch\"` folder at the images. At around 100 epochs, the error is ~3 pixels (but this can vary on how your demo data was split for training)."
]
},
{
@@ -236,7 +223,7 @@
"\n",
"The results are stored in hd5 file in the same directory where the video resides. \n",
"\n",
- "**On the demo data, this should take around ~ 3 min! (The demo frames are 640x480, which should run around 35 FPS on the google-provided GPU)**"
+ "**On the demo data, this should take around ~ 90 seconds! (The demo frames are 640x480, which should run around 25 FPS on the google-provided T4 GPU)**"
]
},
{
@@ -247,8 +234,9 @@
},
"outputs": [],
"source": [
- "videofile_path = ['/content/cloned-DLC-repo/examples/openfield-Pranav-2018-10-30/videos/m3v1mp4.mp4'] #Enter the list of videos to analyze.\n",
- "deeplabcut.analyze_videos(path_config_file,videofile_path, videotype='.mp4')"
+ "# Enter the list of videos to analyze.\n",
+ "videofile_path = [\"/content/cloned-DLC-repo/examples/openfield-Pranav-2018-10-30/videos/m3v1mp4.mp4\"]\n",
+ "deeplabcut.analyze_videos(path_config_file, videofile_path, videotype=\".mp4\")"
]
},
{
@@ -269,7 +257,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,videofile_path)"
+ "deeplabcut.create_labeled_video(path_config_file, videofile_path)"
]
},
{
@@ -290,7 +278,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.plot_trajectories(path_config_file,videofile_path)"
+ "deeplabcut.plot_trajectories(path_config_file, videofile_path)"
]
}
],
diff --git a/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb b/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb
new file mode 100644
index 0000000000..e16d2b055f
--- /dev/null
+++ b/examples/COLAB/COLAB_YOURDATA_SuperAnimal.ipynb
@@ -0,0 +1,2215 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github"
+ },
+ "source": [
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5SSZpZUu0Z4S"
+ },
+ "source": [
+ "# DeepLabCut Model Zoo: SuperAnimal models\n",
+ "\n",
+ "\n",
+ "\n",
+ "# 🦄 SuperAnimal in DeepLabCut PyTorch! 🔥\n",
+ "\n",
+ "This notebook demos how to use our SuperAnimal models within DeepLabCut 3.0! Please read more in [Ye et al. Nature Communications 2024](https://www.nature.com/articles/s41467-024-48792-2) about the available SuperAnimal models, and follow along below!\n",
+ "\n",
+ "### **Let's get going: install the latest version of DeepLabCut into COLAB:**\n",
+ "\n",
+ "*Also, be sure you are connected to a GPU: go to menu, click Runtime > Change Runtime Type > select \"GPU\"*\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "AjET5cJE5UYM",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "290a589f-a063-4933-d315-e13052ec1024"
+ },
+ "outputs": [],
+ "source": "!pip install --pre deeplabcut"
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5h0vq6E50Z4W"
+ },
+ "source": [
+ "**PLEASE, click \"restart runtime\" from the output above before proceeding!**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LvnlIvQm0Z4X",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "ef4fd2ed-4569-41d4-b78a-8bf5ae9a0e6b"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from PIL import Image\n",
+ "\n",
+ "import deeplabcut\n",
+ "import deeplabcut.utils.auxiliaryfunctions as auxiliaryfunctions\n",
+ "from deeplabcut.pose_estimation_pytorch.apis import (\n",
+ " superanimal_analyze_images,\n",
+ ")\n",
+ "from deeplabcut.modelzoo import build_weight_init\n",
+ "from deeplabcut.modelzoo.utils import (\n",
+ " create_conversion_table,\n",
+ " read_conversion_table_from_csv,\n",
+ ")\n",
+ "from deeplabcut.modelzoo.video_inference import video_inference_superanimal\n",
+ "from deeplabcut.utils.pseudo_label import keypoint_matching"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UeXjmtu40Z4X"
+ },
+ "source": [
+ "## Zero-shot Image & Video Inference\n",
+ "SuperAnimal models are foundation animal pose models. They can be used for zero-shot predictions without further training on the data.\n",
+ "In this section, we show how to use SuperAnimal models to predict pose from images (given an image folder) and output the predicted images (with pose) into another destination folder."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FvFzntDMxPoL"
+ },
+ "source": [
+ "### Zero-shot image inference\n",
+ "\n",
+ "If you have a single Image you want to test, upload it here!"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "NbDsZQfsxPoL"
+ },
+ "source": [
+ "#### Upload the images you want to predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "c4yfTj7r0Z4Y",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import files\n",
+ "\n",
+ "uploaded = files.upload()\n",
+ "for filepath, content in uploaded.items():\n",
+ " print(f\"User uploaded file '{filepath}' with length {len(content)} bytes\")\n",
+ "image_path = os.path.abspath(filepath)\n",
+ "image_name = os.path.splitext(image_path)[0]\n",
+ "\n",
+ "# If this cell fails (e.g., when using Safari in place of Google Chrome),\n",
+ "# manually upload your video via the Files menu to the left\n",
+ "# and define `image_path` yourself with right click > copy path on the image:\n",
+ "#\n",
+ "# image_path = \"/path/to/my/image.png\"\n",
+ "# image_name = os.path.splitext(image_path)[0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Jashzdjb0Z4Y"
+ },
+ "source": [
+ "#### Select a SuperAnimal name and corresponding model architecture\n",
+ "\n",
+ "Check Our Docs on [SuperAnimals](https://github.com/DeepLabCut/DeepLabCut/blob/main/docs/ModelZoo.md) to learn more!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uH9LXig90Z4Y",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown ---\n",
+ "# @markdown SuperAnimal Configurations\n",
+ "superanimal_name = \"superanimal_topviewmouse\" #@param [\"superanimal_topviewmouse\", \"superanimal_quadruped\"]\n",
+ "model_name = \"hrnet_w32\" #@param [\"hrnet_w32\"]\n",
+ "detector_name = \"fasterrcnn_resnet50_fpn_v2\" #@param [\"fasterrcnn_resnet50_fpn_v2\"]\n",
+ "\n",
+ "# @markdown ---\n",
+ "# @markdown What is the maximum number of animals you expect to have in an image\n",
+ "max_individuals = 3 # @param {type:\"slider\", min:1, max:30, step:1}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OmJtVmHq0Z4Y",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Note you need to enter max_individuals correctly to get the correct number of predictions in the image.\n",
+ "_ = superanimal_analyze_images(\n",
+ " superanimal_name,\n",
+ " model_name,\n",
+ " detector_name,\n",
+ " image_path,\n",
+ " max_individuals,\n",
+ " out_folder=\"/content/\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6VEjHu-00Z4Y"
+ },
+ "source": [
+ "### Zero-shot Video Inference\n",
+ "\n",
+ "This can be done with or without video adaptation (faster, but not self-supervised fine-tuned on your data!)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "qGoAhxZOxPoM"
+ },
+ "source": [
+ "#### Upload a video you want to predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "PK3efA0I0Z4Y",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from google.colab import files\n",
+ "\n",
+ "uploaded = files.upload()\n",
+ "for filepath, content in uploaded.items():\n",
+ " print(f\"User uploaded file '{filepath}' with length {len(content)} bytes\")\n",
+ "video_path = os.path.abspath(filepath)\n",
+ "video_name = os.path.splitext(video_path)[0]\n",
+ "\n",
+ "# If this cell fails (e.g., when using Safari in place of Google Chrome),\n",
+ "# manually upload your video via the Files menu to the left\n",
+ "# and define `video_path` yourself with right click > copy path on the video."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "JoA-RATSICj_"
+ },
+ "source": [
+ "#### Choose the superanimal and the model name"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "OiRAP9XD0Z4Z",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown ---\n",
+ "# @markdown SuperAnimal Configurations\n",
+ "superanimal_name = \"superanimal_topviewmouse\" #@param [\"superanimal_topviewmouse\", \"superanimal_quadruped\"]\n",
+ "model_name = \"hrnet_w32\" #@param [\"hrnet_w32\"]\n",
+ "detector_name = \"fasterrcnn_resnet50_fpn_v2\" #@param [\"fasterrcnn_resnet50_fpn_v2\"]\n",
+ "\n",
+ "# @markdown ---\n",
+ "# @markdown What is the maximum number of animals you expect to have in an image\n",
+ "max_individuals = 3 # @param {type:\"slider\", min:1, max:30, step:1}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Zv3v0QgSJNOg"
+ },
+ "source": [
+ "#### Zero-shot Video Inference without video adaptation\n",
+ "\n",
+ "The labeled video (and pose predictions for the video) are saved in `\"/content/\"`, with the labeled video name being `{your_video_name}_superanimal_{superanimal_name}_hrnetw32_labeled.mp4`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "poqynL0UJTBp",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "_ = video_inference_superanimal(\n",
+ " videos=video_path,\n",
+ " superanimal_name=superanimal_name,\n",
+ " model_name=model_name,\n",
+ " detector_name=detector_name,\n",
+ " video_adapt=False,\n",
+ " max_individuals=max_individuals,\n",
+ " dest_folder=\"/content/\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Z8Z5GSti0Z4Z"
+ },
+ "source": [
+ "#### Zero-shot Video Inference with video adaptation (unsupervised)\n",
+ "\n",
+ "The labeled video (and pose predictions for the video) are saved in `\"/content/\"`, with the labeled video name being `{your_video_name}_superanimal_{superanimal_name}_hrnetw32_labeled_after_adapt.mp4`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "5mhOmtzw0Z4Z",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "_ = video_inference_superanimal(\n",
+ " videos=[video_path],\n",
+ " superanimal_name=superanimal_name,\n",
+ " model_name=model_name,\n",
+ " detector_name=detector_name,\n",
+ " video_adapt=True,\n",
+ " max_individuals=max_individuals,\n",
+ " pseudo_threshold=0.1,\n",
+ " bbox_threshold=0.9,\n",
+ " detector_epochs=1,\n",
+ " pose_epochs=1,\n",
+ " dest_folder=\"/content/\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "br3pwGf40Z4a"
+ },
+ "source": [
+ "## Training with SuperAnimal\n",
+ "\n",
+ "In this section, we compare different ways to train models in DeepLabCut 3.0, with or without using SuperAnimal-pretrained models.\n",
+ "You can compare the evaluation results and get a sense of each baseline. We have following baselines:\n",
+ "\n",
+ "- ImageNet transfer learning (training without superanimal)\n",
+ "- SuperAnimal transfer learning (baseline 1)\n",
+ "- SuperAnimal naive fine-tuning (baseline 2)\n",
+ "- SuperAnimal memory-replay fine-tuning (baseline3)\n",
+ "\n",
+ "This is done on one of your DeepLabCut projects! If you don't have a DeepLabCut project that you can use SuperAnimal models with, you can always using the example openfield dataset [available in the DeepLabCut repository](https://github.com/DeepLabCut/DeepLabCut/tree/main/examples/openfield-Pranav-2018-10-30) or the Tri-Mouse dataset available on [Zenodo](https://zenodo.org/records/5851157)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "yPy5VgDDhD6o"
+ },
+ "source": [
+ "### Preparing the DeepLabCut Project\n",
+ "\n",
+ "First, place your DeepLabCut project folder into you google drive! \"i.e. move the folder named \"Project-YourName-TheDate\" into Google Drive."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SXzBBV8ehDR9",
+ "outputId": "90d61c19-400b-4e5d-8ac9-63680d72cdb5"
+ },
+ "outputs": [],
+ "source": [
+ "# Now, let's link to your GoogleDrive. Run this cell and follow the\n",
+ "# authorization instructions:\n",
+ "\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-QmTftBMo4h6"
+ },
+ "source": [
+ "You will need to edit the project path in the config.yaml file to be set to your Google Drive link!\n",
+ "\n",
+ "Typically, this will be in the format: `/content/drive/MyDrive/yourProjectFolderName`. You can obtain this path by going to the file navigator in the left pane, finding your DeepLabCut project folder, clicking on the vertical `...` next to the folder name and selecting \"Copy path\".\n",
+ "\n",
+ "If the `drive` folder is not immediately visible after mounting the drive, refresh the available files!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "_iFFEYAB7Uum"
+ },
+ "outputs": [],
+ "source": [
+ "# TODO: Update the `project_path` to be the path of your DeepLabCut project!\n",
+ "project_path = Path(\"/content/drive/MyDrive/my-project-2024-07-17\")\n",
+ "config_path = str(project_path / \"config.yaml\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HZTG3Eo475w0"
+ },
+ "source": [
+ "Then, use the panel below to select the appropriate SuperAnimal model for your project (don't forget to run the cell)!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "t8NtCy1Jo0bu"
+ },
+ "outputs": [],
+ "source": [
+ "# @markdown ---\n",
+ "# @markdown SuperAnimal Configurations\n",
+ "superanimal_name = \"superanimal_topviewmouse\" #@param [\"superanimal_topviewmouse\", \"superanimal_quadruped\"]\n",
+ "model_name = \"hrnet_w32\" #@param [\"hrnet_w32\"]\n",
+ "detector_name = \"fasterrcnn_resnet50_fpn_v2\" #@param [\"fasterrcnn_resnet50_fpn_v2\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BPvoL9uZ0Z4a"
+ },
+ "source": [
+ "### Comparison between different training baselines\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eVmpaLdB0Z4a"
+ },
+ "source": [
+ "Definition of data split: the unique combination of training images and testing images.\n",
+ "We create a data split named split 0. All baselines will share the data split to make fair comparisons.\n",
+ "- split 0 -> shared by all baselines\n",
+ "- shuffle 0 (split0) -> imagenet transfer learning\n",
+ "- shuffle 1 (split0) -> superanimal transfer learning\n",
+ "- shuffle 2 (split0) -> superanimal naive fine-tuning\n",
+ "- shuffle 3 (split0) -> superanimal memory-replay fine-tuning"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "WofR2jytxPoR"
+ },
+ "source": [
+ "### What is the difference between baselines?\n",
+ "\n",
+ "**Transfer learning** For canonical task-agnostic transfer learning,\n",
+ "the encoder learns universal visual features from a large pre-training dataset, and a randomly\n",
+ "initialized decoder is used to learn the pose from the downstream dataset.\n",
+ "\n",
+ "**Fine-tuning** For task aware\n",
+ "fine-tuning, both encoder and decoder learn task-related visual-pose features\n",
+ "in the pre-training datasets, and the decoder is fine-tuned to update pose\n",
+ "priors in downstream datasets. Crucially, the network has pose-estimation-specific\n",
+ "weights\n",
+ "\n",
+ "**ImageNet transfer-learning** The encoder was pre-trained from ImageNet. The decoder is trained from scratch in the downstream tasks\n",
+ "\n",
+ "**SuperAnimal transfer-learning** The encoder was pre-trained first from ImageNet, then in pose datasets we colleceted. Then decoder is trained from scratch in downstream tasks.\n",
+ "\n",
+ "**SuperAnimal naive fine-tuning** Both the encoder and the decoder were pre-trained in pose datasets we collected. In downstream datasets, we only finetune convolutional channels that correspond to the annotated keypoints in the downstream datasets. This introduces catastrophic forgetting in keypoints that are not annotated in the downstream datasets.\n",
+ "\n",
+ "**SuperAnimal memory-replay fine-tuning** If we apply fine-tuning with SuperAnimal without further cares, the models will forget about keypoints that are not annotated in the downstream datasets. To mitigate this, we mix the annotations and zero-shot predictions of SuperAnimal models to create a dataset that 'replays' the memory of the SuperAnimal keypoints.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "AgIsUu6v0Z4a",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "imagenet_transfer_learning_shuffle = 0\n",
+ "superanimal_transfer_learning_shuffle = 1\n",
+ "superanimal_naive_finetune_shuffle = 2\n",
+ "superanimal_memory_replay_shuffle = 3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "kuKcxM8F0Z4a",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "c7df2943-1e2c-4b85-c20d-8b94a8aabd75"
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.create_training_dataset(\n",
+ " config_path,\n",
+ " Shuffles=[imagenet_transfer_learning_shuffle],\n",
+ " net_type=f\"top_down_{model_name}\",\n",
+ " detector_type=detector_name,\n",
+ " engine=deeplabcut.Engine.PYTORCH,\n",
+ " userfeedback=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_6RncQbr0Z4a"
+ },
+ "source": [
+ "### ImageNet transfer learning\n",
+ "\n",
+ "Historically, the transfer learning using ImageNet weights strategies assumed no “animal pose task priors” in the pretrained\n",
+ "model, a paradigm adopted from previous task-agnostic transfer learning.\n",
+ "\n",
+ "You can change the number of epochs you want to train for. How long training will take depends on many parameters, including the number of images in your dataset, the resolution of the images, and the number of epochs you train for."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000,
+ "referenced_widgets": [
+ "7ed11ae2a4be462da84ff716e0725af0",
+ "0f0ed94a863f49b9b85d0a18fa8ce2a5",
+ "343f2670d37c4bf18859238c3d81d419",
+ "d104ae21091e4f10a7de18e191b9f04d",
+ "5dcbd8f3fb6148cca6cfc72b20ce49bd",
+ "e1675e53ca9a4da8acf6c16fba7a2578",
+ "3d2996e10f96404baf24d2c4215b75a1",
+ "b988f87e676840ee98daa3d996c9ddbc",
+ "1779b84e748b4989a8ed53434c30016f",
+ "d37cf6fe7c444bc2a2568c3407389ea8",
+ "2cef5e028d2e40a6bba7400be922d0c2"
+ ]
+ },
+ "id": "H2z8kM340Z4a",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "75cc2c95-2ac7-4354-9134-4847937e15ce"
+ },
+ "outputs": [],
+ "source": [
+ "# Note we skip the detector training to save time.\n",
+ "# For Top-Down models, the evaluation is by default using ground-truth bounding\n",
+ "# boxes. But to train a model that can be used to inference videos and images,\n",
+ "# you have to set detector_epochs > 0.\n",
+ "\n",
+ "deeplabcut.train_network(\n",
+ " config_path,\n",
+ " detector_epochs=0,\n",
+ " epochs=50,\n",
+ " save_epochs=10,\n",
+ " batch_size=64, # if you get a CUDA OOM error when training on a GPU, reduce to 32, 16, ...!\n",
+ " displayiters=10,\n",
+ " shuffle=imagenet_transfer_learning_shuffle,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "J-udMck7nDbG"
+ },
+ "source": [
+ "Now let's evaluate the performance of our trained models."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "TDHMdKz4m_16",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "1d38fb84-7f4c-45d1-dbcd-fd7117ca4dad"
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.evaluate_network(config_path, Shuffles=[imagenet_transfer_learning_shuffle])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0GIFWU-MxPoR"
+ },
+ "source": [
+ "### Transfer learning with SuperAnimal weights\n",
+ "\n",
+ "First, we prepare training shuffle for transfer-learning with SuperAnimal weights. As we've already create a shuffle with a train/test split that we want to reuse, we use `deeplabcut.create_training_dataset_from_existing_split` to keep the same train/test indices as in the ImageNet transfer learning shuffle.\n",
+ "\n",
+ "We specify that we want to initialize the model weights with the selected SuperAnimal model, but without keeping the decoding layers (this is called transfer learning)!\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "wOSdZQtOp8qa",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "ea721606-ea9f-444b-cdae-f62cf0ad30be"
+ },
+ "outputs": [],
+ "source": [
+ "weight_init = build_weight_init(\n",
+ " cfg=auxiliaryfunctions.read_config(config_path), \n",
+ " super_animal=superanimal_name,\n",
+ " model_name=model_name,\n",
+ " detector_name=detector_name,\n",
+ " with_decoder=False,\n",
+ ")\n",
+ "\n",
+ "deeplabcut.create_training_dataset_from_existing_split(\n",
+ " config_path,\n",
+ " from_shuffle=imagenet_transfer_learning_shuffle,\n",
+ " shuffles=[superanimal_transfer_learning_shuffle],\n",
+ " engine=deeplabcut.Engine.PYTORCH,\n",
+ " net_type=f\"top_down_{model_name}\",\n",
+ " detector_type=detector_name,\n",
+ " weight_init=weight_init,\n",
+ " userfeedback=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "3qFxlRHixPoR"
+ },
+ "source": [
+ "Then, we launch the training for transfer-learning with SuperAnimal weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000,
+ "referenced_widgets": [
+ "9a996c8dc3b34bc5b8805b3687e22b27",
+ "d012b421c189412dabeac84cba4164a7",
+ "1abff22a7c9a416d9166e6b150612171",
+ "7271412c1f0141649a7300dbce2b003c",
+ "3c011813d7cb48588a8d236785d9c24f",
+ "3ea385fe815f4e50a0b81ec299040314",
+ "fe59f6c5ed7b4e2cb87bb60224acdaba",
+ "04370d8302c04c5ca6a351383126193f",
+ "d67c4871543e405fbb576a55f8c9048a",
+ "a6cb25fa67ef4733a720960b3fc8213c",
+ "b73b1b64620d492dbc4eaf4bd83ca23a",
+ "dccbe277cc084ed6aa0b329067b5c69c",
+ "c8b57833d3f946abae69b84075345a54",
+ "bee292213d8645618536fcdf6a491d83",
+ "fbbc8c5b20c7423fb21b74296e0eeb28",
+ "ff0c737c49624b1ea27588611951fc84",
+ "42874cdab4be4dc38b0c33775b27d98c",
+ "e3a185abf8a04edabf32d58bdee10dd1",
+ "7cdcbbf9cb694dbf949e8b7eea8e7836",
+ "2ec06260b237411cabd3de7c37e03b1b",
+ "9f8009429aa34b40a65c998230f20c99",
+ "2a3abfe7867641db9fbfe3ee76854bf4"
+ ]
+ },
+ "id": "W60UgRQWqghn",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "18b931b8-98f4-4539-bf82-1910ff5b7f70"
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.train_network(\n",
+ " config_path,\n",
+ " detector_epochs=0,\n",
+ " epochs=50,\n",
+ " save_epochs=10,\n",
+ " batch_size=64, # if you get a CUDA OOM error when training on a GPU, reduce to 32, 16, ...!\n",
+ " displayiters=10,\n",
+ " shuffle=superanimal_transfer_learning_shuffle,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XzOWKiOixPoR"
+ },
+ "source": [
+ "Finally, we evaluate the model obtained by transfer-learning with SuperAnimal weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jpO3aIAIsWbz",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "30415e5b-8011-4651-af77-a781ea2b5af7"
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.evaluate_network(config_path, Shuffles=[superanimal_transfer_learning_shuffle])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_Es6RR-_0Z4b"
+ },
+ "source": [
+ "### Fine-tuning with SuperAnimal (without keeping full SuperAnimal keypoints)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "6oo9oJ8XyZrn"
+ },
+ "source": [
+ "#### Setup the weight init and dataset\n",
+ "\n",
+ "First we do keypoint matching. This steps make it possible to understand the correspondence between the existing annotations and SuperAnimal annotations. This step produces 3 outputs\n",
+ "- The confusion matrix\n",
+ "- The conversion table\n",
+ "- Pseudo predictions over the whole dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "fRm62Ji_xPoS"
+ },
+ "source": [
+ "#### What is keypoint matching?\n",
+ "\n",
+ "Because SuperAnimal models have their pre-defined keypoints that are potentially different from your annotations, we proposed this algorithm to minimize the gap between the model and the dataset. We use our model to perform zero-shot inference on the whole dataset. This gives pairs of predictions and ground truth for every image. Then, we cast the matching between models’ predictions (2D coordinates)\n",
+ "and ground truth as bipartitematching using the Euclidean distance as the cost between paired of keypoints. We then solve the matching using the Hungarian algorithm. Thus for every image, we end up getting a matching matrix where 1 counts formatch and 0 counts for non-matching. Because the models’ predictions can be noisy from image to image, we average the aforementioned matching matrix across all the images and perform another bipartite matching, resulting in the final keypoint conversion table between the model and the dataset. Note that the quality of thematching will impact the performance\n",
+ "of the model, especially for zero-shot. In the case where, e.g., the annotation nose is mistakenly converted to keypoint tail and vice versa, the model will have to unlearn the channel that corresponds to nose and tail (see also case study in Mathis et al.)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "vEHeuKSKyjA6",
+ "jupyter": {
+ "outputs_hidden": true
+ },
+ "outputId": "5863a81e-e0b9-48c7-f2f9-de14d38e805e"
+ },
+ "outputs": [],
+ "source": [
+ "keypoint_matching(\n",
+ " config_path,\n",
+ " superanimal_name,\n",
+ " model_name,\n",
+ " detector_name,\n",
+ " copy_images=True,\n",
+ ")\n",
+ "\n",
+ "conversion_table_path = project_path / \"memory_replay\" / \"conversion_table.csv\"\n",
+ "confusion_matrix_path = project_path / \"memory_replay\" / \"confusion_matrix.png\"\n",
+ "\n",
+ "# You can visualize the pseudo predictions, or do pose embedding clustering etc.\n",
+ "pseudo_prediction_path = project_path / \"memory_replay\" / \"pseudo_predictions.json\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "sA8yyLgs0zoO"
+ },
+ "source": [
+ "#### Display the confusion matrix\n",
+ "\n",
+ "The x axis lists the keypoints in the existing annotations. The y axis lists the keypoints in SuperAnimal keypoint space. Darker color encodes stronger correspondence between the human annotation and SuperAnimal annotations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "luDxpD9H0zYZ",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "confusion_matrix_image = Image.open(confusion_matrix_path)\n",
+ "\n",
+ "plt.imshow(confusion_matrix_image)\n",
+ "plt.axis('off') # Hide the axes for better view\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i0QWikYmy_Mj"
+ },
+ "source": [
+ "#### Display the conversion table\n",
+ "The gt columns represents the keypoint names in the existing dataset. The MasterName represents the corresponding keypoints in SuperAnimal keypoint space."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "CeA-NzDMynYV",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "df = pd.read_csv(conversion_table_path)\n",
+ "df = df.dropna()\n",
+ "\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Adding the Conversion Table to your project's `config.yaml` file\n",
+ "\n",
+ "Once you've run keypoint matching, you can add the conversion table to your project's `config.yaml` file, and edit it if there are some matches you think are wrong. As an example, for a top-view mouse dataset with 4 bodyparts labeled (`'snout', 'leftear', 'rightear', 'tailbase'`), the conversion table mapping project bodyparts to SuperAnimal bodyparts would be added as:\n",
+ "\n",
+ "```yaml\n",
+ "# Conversion tables to fine-tune SuperAnimal weights\n",
+ "SuperAnimalConversionTables:\n",
+ " superanimal_topviewmouse:\n",
+ " snout: nose\n",
+ " leftear: left_ear\n",
+ " rightear: right_ear\n",
+ " tailbase: tail_base\n",
+ "```\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "create_conversion_table(\n",
+ " config=config_path,\n",
+ " super_animal=superanimal_name,\n",
+ " project_to_super_animal=read_conversion_table_from_csv(\n",
+ " conversion_table_path\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "GkfIo8zTxPoS"
+ },
+ "source": [
+ "#### Prepare the training shuffle and weight initialization for (naive) fine-tuning with SuperAnimal weights\n",
+ "\n",
+ "Then, when you call `build_weight_init` with `with_decoder=True`, the conversion table in your project's `config.yaml` is used to get predictions for the correct bodyparts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "xEeM_hrOu6k8",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "weight_init = build_weight_init(\n",
+ " cfg=auxiliaryfunctions.read_config(config_path), \n",
+ " super_animal=superanimal_name,\n",
+ " model_name=model_name,\n",
+ " detector_name=detector_name,\n",
+ " with_decoder=True,\n",
+ ")\n",
+ "\n",
+ "deeplabcut.create_training_dataset_from_existing_split(\n",
+ " config_path,\n",
+ " from_shuffle=imagenet_transfer_learning_shuffle,\n",
+ " shuffles=[superanimal_naive_finetune_shuffle],\n",
+ " engine=deeplabcut.Engine.PYTORCH,\n",
+ " net_type=f\"top_down_{model_name}\",\n",
+ " detector_type=detector_name,\n",
+ " weight_init=weight_init,\n",
+ " userfeedback=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gZx6nr-ExPoS"
+ },
+ "source": [
+ "#### Launch the training for (naive) fine-tuning with SuperAnimal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "c3XAr6uRyXOD",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.train_network(\n",
+ " config_path,\n",
+ " detector_epochs=0,\n",
+ " epochs=50,\n",
+ " save_epochs=10,\n",
+ " batch_size=64, # if you get a CUDA OOM error when training on a GPU, reduce to 32, 16, ...!\n",
+ " displayiters=10,\n",
+ " shuffle=superanimal_naive_finetune_shuffle,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "oXuRshzhxPoS"
+ },
+ "source": [
+ "#### Evaluate the model obtained by (naive) fine-tuning with SuperAnimal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VXfdKS-H2yqw",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.evaluate_network(\n",
+ " config_path,\n",
+ " Shuffles=[superanimal_naive_finetune_shuffle],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "_nUAMlbZ0Z4b"
+ },
+ "source": [
+ "### Memory-replay fine-tuning with SuperAnimal (keeping full SuperAnimal keypoints)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "n6HPu6RaxPoS"
+ },
+ "source": [
+ "**Catastrophic forgetting** describes a\n",
+ "classic problemin continual learning. Indeed, amodel gradually loses\n",
+ "its ability to solve previous tasks after it learns to solve new ones.\n",
+ "Fine-tuning a SuperAnimal models falls into the category of continual\n",
+ "learning: the downstream dataset defines potentially different\n",
+ "keypoints than those learned by the models. Thus, the models might\n",
+ "forget the keypoints they learned and only pick up those defined in the\n",
+ "target dataset. Here, retraining with the original dataset and the new\n",
+ "one, is not a feasible option as datasets cannot be easily shared and\n",
+ "more computational resources would be required.\n",
+ "To counter that, we treat zero-shot inference of the model as a\n",
+ "memory buffer that stores knowledge from the original model. When\n",
+ "we fine-tune a SuperAnimal model, we replace the model predicted\n",
+ "keypoints with the ground-truth annotations, resulting in hybrid\n",
+ "learning of old and new knowledge. The quality of the zero-shot predictions\n",
+ "can vary and we use the confidence of prediction (0.7) as a\n",
+ "threshold to filter out low-confidence predictions. With the threshold\n",
+ "set to 1, memory replay fine-tuning becomes naive-fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "CSLmjlCIxPoS"
+ },
+ "source": [
+ "#### Prepare training shuffle and weight initialization for memory-replay finetuning with SuperAnimal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "BKEF76AI0Z4c",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "weight_init = build_weight_init(\n",
+ " cfg=auxiliaryfunctions.read_config(config_path), \n",
+ " super_animal=superanimal_name,\n",
+ " model_name=model_name,\n",
+ " detector_name=detector_name,\n",
+ " with_decoder=True,\n",
+ " memory_replay=True,\n",
+ ")\n",
+ "\n",
+ "deeplabcut.create_training_dataset_from_existing_split(\n",
+ " config_path,\n",
+ " from_shuffle=imagenet_transfer_learning_shuffle,\n",
+ " shuffles=[superanimal_memory_replay_shuffle],\n",
+ " engine=deeplabcut.Engine.PYTORCH,\n",
+ " net_type=f\"top_down_{model_name}\",\n",
+ " detector_type=detector_name,\n",
+ " weight_init=weight_init,\n",
+ " userfeedback=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "MKwJiIyKxPoT"
+ },
+ "source": [
+ "#### Launch the training for memory-replay fine-tuning with SuperAnimal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Ru8tIFmD2Mkv",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.train_network(\n",
+ " config_path,\n",
+ " detector_epochs=0,\n",
+ " epochs=50,\n",
+ " save_epochs=10,\n",
+ " batch_size=64, # if you get a CUDA OOM error when training on a GPU, reduce to 32, 16, ...!\n",
+ " displayiters=10,\n",
+ " shuffle=superanimal_memory_replay_shuffle,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "i-2MBRDjxPoT"
+ },
+ "source": [
+ "#### Evaluate the model obtained by memory-replay finetuning with SuperAnimal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sfMcK3gq8WxZ",
+ "jupyter": {
+ "outputs_hidden": true
+ }
+ },
+ "outputs": [],
+ "source": [
+ "deeplabcut.evaluate_network(config_path, Shuffles=[superanimal_memory_replay_shuffle])"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [
+ "UeXjmtu40Z4X",
+ "FvFzntDMxPoL",
+ "6VEjHu-00Z4Y"
+ ],
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "04370d8302c04c5ca6a351383126193f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0f0ed94a863f49b9b85d0a18fa8ce2a5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e1675e53ca9a4da8acf6c16fba7a2578",
+ "placeholder": "",
+ "style": "IPY_MODEL_3d2996e10f96404baf24d2c4215b75a1",
+ "value": "model.safetensors: 100%"
+ }
+ },
+ "1779b84e748b4989a8ed53434c30016f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "1abff22a7c9a416d9166e6b150612171": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_04370d8302c04c5ca6a351383126193f",
+ "max": 159594859,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_d67c4871543e405fbb576a55f8c9048a",
+ "value": 159594859
+ }
+ },
+ "2a3abfe7867641db9fbfe3ee76854bf4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2cef5e028d2e40a6bba7400be922d0c2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2ec06260b237411cabd3de7c37e03b1b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "343f2670d37c4bf18859238c3d81d419": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_b988f87e676840ee98daa3d996c9ddbc",
+ "max": 165432914,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_1779b84e748b4989a8ed53434c30016f",
+ "value": 165432914
+ }
+ },
+ "3c011813d7cb48588a8d236785d9c24f": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "3d2996e10f96404baf24d2c4215b75a1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "3ea385fe815f4e50a0b81ec299040314": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "42874cdab4be4dc38b0c33775b27d98c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5dcbd8f3fb6148cca6cfc72b20ce49bd": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7271412c1f0141649a7300dbce2b003c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a6cb25fa67ef4733a720960b3fc8213c",
+ "placeholder": "",
+ "style": "IPY_MODEL_b73b1b64620d492dbc4eaf4bd83ca23a",
+ "value": " 160M/160M [00:00<00:00, 201MB/s]"
+ }
+ },
+ "7cdcbbf9cb694dbf949e8b7eea8e7836": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7ed11ae2a4be462da84ff716e0725af0": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0f0ed94a863f49b9b85d0a18fa8ce2a5",
+ "IPY_MODEL_343f2670d37c4bf18859238c3d81d419",
+ "IPY_MODEL_d104ae21091e4f10a7de18e191b9f04d"
+ ],
+ "layout": "IPY_MODEL_5dcbd8f3fb6148cca6cfc72b20ce49bd"
+ }
+ },
+ "9a996c8dc3b34bc5b8805b3687e22b27": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_d012b421c189412dabeac84cba4164a7",
+ "IPY_MODEL_1abff22a7c9a416d9166e6b150612171",
+ "IPY_MODEL_7271412c1f0141649a7300dbce2b003c"
+ ],
+ "layout": "IPY_MODEL_3c011813d7cb48588a8d236785d9c24f"
+ }
+ },
+ "9f8009429aa34b40a65c998230f20c99": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a6cb25fa67ef4733a720960b3fc8213c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "b73b1b64620d492dbc4eaf4bd83ca23a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b988f87e676840ee98daa3d996c9ddbc": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "bee292213d8645618536fcdf6a491d83": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7cdcbbf9cb694dbf949e8b7eea8e7836",
+ "max": 517816013,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_2ec06260b237411cabd3de7c37e03b1b",
+ "value": 517816013
+ }
+ },
+ "c8b57833d3f946abae69b84075345a54": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_42874cdab4be4dc38b0c33775b27d98c",
+ "placeholder": "",
+ "style": "IPY_MODEL_e3a185abf8a04edabf32d58bdee10dd1",
+ "value": "detector.pt: 100%"
+ }
+ },
+ "d012b421c189412dabeac84cba4164a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_3ea385fe815f4e50a0b81ec299040314",
+ "placeholder": "",
+ "style": "IPY_MODEL_fe59f6c5ed7b4e2cb87bb60224acdaba",
+ "value": "pose_model.pth: 100%"
+ }
+ },
+ "d104ae21091e4f10a7de18e191b9f04d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d37cf6fe7c444bc2a2568c3407389ea8",
+ "placeholder": "",
+ "style": "IPY_MODEL_2cef5e028d2e40a6bba7400be922d0c2",
+ "value": " 165M/165M [00:04<00:00, 41.1MB/s]"
+ }
+ },
+ "d37cf6fe7c444bc2a2568c3407389ea8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d67c4871543e405fbb576a55f8c9048a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "dccbe277cc084ed6aa0b329067b5c69c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_c8b57833d3f946abae69b84075345a54",
+ "IPY_MODEL_bee292213d8645618536fcdf6a491d83",
+ "IPY_MODEL_fbbc8c5b20c7423fb21b74296e0eeb28"
+ ],
+ "layout": "IPY_MODEL_ff0c737c49624b1ea27588611951fc84"
+ }
+ },
+ "e1675e53ca9a4da8acf6c16fba7a2578": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e3a185abf8a04edabf32d58bdee10dd1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "fbbc8c5b20c7423fb21b74296e0eeb28": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_9f8009429aa34b40a65c998230f20c99",
+ "placeholder": "",
+ "style": "IPY_MODEL_2a3abfe7867641db9fbfe3ee76854bf4",
+ "value": " 518M/518M [00:05<00:00, 101MB/s]"
+ }
+ },
+ "fe59f6c5ed7b4e2cb87bb60224acdaba": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ff0c737c49624b1ea27588611951fc84": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb b/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb
index 989655c7a5..5a92ec0892 100644
--- a/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb
+++ b/examples/COLAB/COLAB_YOURDATA_TrainNetwork_VideoAnalysis.ipynb
@@ -18,7 +18,12 @@
},
"source": [
"# DeepLabCut Toolbox - Colab for standard (single animal) projects!\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "\n",
+ "Some useful links:\n",
+ "\n",
+ "- [DeepLabCut's GitHub: github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Single Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html)\n",
+ "\n",
"\n",
"This notebook illustrates how to use the cloud to:\n",
"- create a training set\n",
@@ -59,50 +64,41 @@
"metadata": {},
"outputs": [],
"source": [
- "# Install TensorFlow, tensorpack and tf_slim versions compatible with DeepLabCut\n",
- "!pip install \"tensorflow==2.12.1\" \"tensorpack>=0.11\" \"tf_slim>=1.1.0\""
+ "# this will take a couple of minutes to install all the dependencies!\n",
+ "!pip install --pre deeplabcut"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Downgrade PyTorch to a version using CUDA 11.8 and cudnn 8\n",
- "# This will also install the required CUDA libraries, for both PyTorch and TensorFlow\n",
- "!pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cu118"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text",
+ "id": "25wSj6TlVclR"
+ },
"source": [
- "# Install the latest version of DeepLabCut\n",
- "!pip install \"git+https://github.com/DeepLabCut/DeepLabCut.git\""
+ "**(Be sure to click \"RESTART RUNTIME\" if it is displayed above before moving on !)** You will see this button at the output of the cells above ^."
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# As described in https://www.tensorflow.org/install/pip#step-by-step_instructions, \n",
- "# create symbolic links to NVIDIA shared libraries:\n",
- "!ln -svf /usr/local/lib/python3.11/dist-packages/nvidia/*/lib/*.so* /usr/local/lib/python3.11/dist-packages/tensorflow"
- ]
- },
- {
- "cell_type": "markdown",
+ "execution_count": 2,
"metadata": {
- "colab_type": "text",
- "id": "25wSj6TlVclR"
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "oTwAcbq2-FZz",
+ "outputId": "9cfd8dcf-a0a8-4801-ed1d-fbcd5ec056af"
},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "DLC loaded in light mode; you cannot use any GUI (labeling, relabeling and standalone GUI)\n"
+ ]
+ }
+ ],
"source": [
- "**(Be sure to click \"RESTART RUNTIME\" if it is displayed above before moving on !)**"
+ "import deeplabcut"
]
},
{
@@ -127,11 +123,12 @@
},
"outputs": [],
"source": [
- "#Now, let's link to your GoogleDrive. Run this cell and follow the authorization instructions:\n",
- "#(We recommend putting a copy of the github repo in your google drive if you are using the demo \"examples\")\n",
+ "# Now, let's link to your GoogleDrive. Run this cell and follow the authorization instructions:\n",
+ "# (We recommend putting a copy of the github repo in your google drive if you are using the demo \"examples\")\n",
"\n",
"from google.colab import drive\n",
- "drive.mount('/content/drive')"
+ "\n",
+ "drive.mount(\"/content/drive\")"
]
},
{
@@ -143,7 +140,7 @@
"source": [
"YOU WILL NEED TO EDIT THE PROJECT PATH **in the config.yaml file** TO BE SET TO YOUR GOOGLE DRIVE LINK!\n",
"\n",
- "Typically, this will be: /content/drive/My Drive/yourProjectFolderName\n"
+ "Typically, this will be: `/content/drive/My Drive/yourProjectFolderName`\n"
]
},
{
@@ -156,57 +153,25 @@
},
"outputs": [],
"source": [
- "#Setup your project variables:\n",
- "# PLEASE EDIT THESE:\n",
- " \n",
- "ProjectFolderName = 'myproject-teamDLC-2020-03-29'\n",
- "VideoType = 'mp4' \n",
+ "# PLEASE EDIT THIS:\n",
+ "project_folder_name = \"MontBlanc-Daniel-2019-12-16\"\n",
+ "video_type = \"mp4\" #, mp4, MOV, or avi, whatever you uploaded!\n",
"\n",
- "#don't edit these:\n",
- "videofile_path = ['/content/drive/My Drive/'+ProjectFolderName+'/videos/'] #Enter the list of videos or folder to analyze.\n",
- "videofile_path"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "3K9Ndy1beyfG"
- },
- "outputs": [],
- "source": [
- "import deeplabcut"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "o4orkg9QTHKK"
- },
- "outputs": [],
- "source": [
- "deeplabcut.__version__"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "Z7ZlDr3wV4D1"
- },
- "outputs": [],
- "source": [
- "#This creates a path variable that links to your google drive copy\n",
- "#No need to edit this, as you set it up before: \n",
- "path_config_file = '/content/drive/My Drive/'+ProjectFolderName+'/config.yaml'\n",
- "path_config_file"
+ "# No need to edit this, we are going to assume you put videos you want to analyze\n",
+ "# in the \"videos\" folder, but if this is NOT true, edit below:\n",
+ "videofile_path = [f\"/content/drive/My Drive/{project_folder_name}/videos/\"]\n",
+ "print(videofile_path)\n",
+ "\n",
+ "# The prediction files and labeled videos will be saved in this `labeled-videos` folder\n",
+ "# in your project folder; if you want them elsewhere, you can edit this;\n",
+ "# if you want the output files in the same folder as the videos, set this to an empty string.\n",
+ "destfolder = f\"/content/drive/My Drive/{project_folder_name}/labeled-videos\"\n",
+ "\n",
+ "#No need to edit this, as you set it when you passed the ProjectFolderName (above):\n",
+ "path_config_file = f\"/content/drive/My Drive/{project_folder_name}/config.yaml\"\n",
+ "print(path_config_file)\n",
+ "\n",
+ "# This creates a path variable that links to your Google Drive project"
]
},
{
@@ -217,10 +182,12 @@
},
"source": [
"## Create a training dataset:\n",
- "### You must do this step inside of Colab:\n",
+ "\n",
+ "### You must do this step inside of Colab\n",
+ "\n",
"After running this script the training dataset is created and saved in the project directory under the subdirectory **'training-datasets'**\n",
"\n",
- "This function also creates new subdirectories under **dlc-models** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pose_cfg.yaml**.\n",
+ "This function also creates new subdirectories under **dlc-models-pytorch** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pytorch_config.yaml**.\n",
"\n",
"Now it is the time to start training the network!"
]
@@ -236,10 +203,12 @@
},
"outputs": [],
"source": [
- "# Note: if you are using the demo data (i.e. examples/Reaching-Mackenzie-2018-08-30/), first delete the folder called dlc-models! \n",
- "#Then, run this cell. There are many more functions you can set here, including which netowkr to use!\n",
- "#check the docstring for full options you can do!\n",
- "deeplabcut.create_training_dataset(path_config_file, net_type='resnet_50', augmenter_type='imgaug')"
+ "# There are many more functions you can set here, including which network to use!\n",
+ "# Check the docstring for `create_training_dataset` for all options you can use!\n",
+ "\n",
+ "deeplabcut.create_training_dataset(\n",
+ " path_config_file, net_type=\"resnet_50\", engine=deeplabcut.Engine.PYTORCH\n",
+ ")"
]
},
{
@@ -263,14 +232,26 @@
},
"outputs": [],
"source": [
- "#let's also change the display and save_iters just in case Colab takes away the GPU... \n",
- "#if that happens, you can reload from a saved point. Typically, you want to train to 200,000 + iterations.\n",
- "#more info and there are more things you can set: https://github.com/DeepLabCut/DeepLabCut/wiki/DOCSTRINGS#train_network\n",
+ "# Let's also change the display and save_epochs just in case Colab takes away\n",
+ "# the GPU... If that happens, you can reload from a saved point using the\n",
+ "# `snapshot_path` argument to `deeplabcut.train_network`:\n",
+ "# deeplabcut.train_network(..., snapshot_path=\"/content/.../snapshot-050.pt\")\n",
+ "\n",
+ "# Typically, you want to train to ~200 epochs. We set the batch size to 8 to\n",
+ "# utilize the GPU's capabilities.\n",
+ "\n",
+ "# More info and there are more things you can set:\n",
+ "# https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#g-train-the-network\n",
"\n",
- "deeplabcut.train_network(path_config_file, shuffle=1, displayiters=10,saveiters=500)\n",
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=5,\n",
+ " epochs=200,\n",
+ " batch_size=8,\n",
+ ")\n",
"\n",
- "#this will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end (default, 1.03M iterations). \n",
- "#Whichever you chose, you will see what looks like an error message, but it's not an error - don't worry...."
+ "# This will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end."
]
},
{
@@ -280,7 +261,7 @@
"id": "RiDwIVf5-3H_"
},
"source": [
- "**When you hit \"STOP\" you will get a KeyInterrupt \"error\"! No worries! :)**"
+ "Note, that **when you hit \"STOP\" you will get a `KeyboardInterrupt` \"error\"! No worries! :)**"
]
},
{
@@ -292,7 +273,7 @@
"source": [
"## Start evaluating:\n",
"This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images)\n",
- "and stores the results as .csv file in a subdirectory under **evaluation-results**"
+ "and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**"
]
},
{
@@ -305,11 +286,10 @@
},
"outputs": [],
"source": [
- "%matplotlib notebook\n",
- "deeplabcut.evaluate_network(path_config_file,plotting=True)\n",
+ "deeplabcut.evaluate_network(path_config_file, plotting=True)\n",
"\n",
- "# Here you want to see a low pixel error! Of course, it can only be as good as the labeler, \n",
- "#so be sure your labels are good! (And you have trained enough ;)"
+ "# Here you want to see a low pixel error! Of course, it can only be as\n",
+ "# good as the labeler, so be sure your labels are good!\n"
]
},
{
@@ -348,7 +328,12 @@
},
"outputs": [],
"source": [
- "deeplabcut.analyze_videos(path_config_file,videofile_path, videotype=VideoType)"
+ "deeplabcut.analyze_videos(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " videotype=video_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -372,7 +357,12 @@
},
"outputs": [],
"source": [
- "deeplabcut.plot_trajectories(path_config_file,videofile_path, videotype=VideoType)"
+ "deeplabcut.plot_trajectories(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " videotype=video_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -406,7 +396,12 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,videofile_path, videotype=VideoType)"
+ "deeplabcut.create_labeled_video(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " videotype=video_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
}
],
diff --git a/examples/COLAB/COLAB_maDLC_TrainNetwork_VideoAnalysis.ipynb b/examples/COLAB/COLAB_YOURDATA_maDLC_TrainNetwork_VideoAnalysis.ipynb
similarity index 57%
rename from examples/COLAB/COLAB_maDLC_TrainNetwork_VideoAnalysis.ipynb
rename to examples/COLAB/COLAB_YOURDATA_maDLC_TrainNetwork_VideoAnalysis.ipynb
index 3da23bd633..903b60b115 100644
--- a/examples/COLAB/COLAB_maDLC_TrainNetwork_VideoAnalysis.ipynb
+++ b/examples/COLAB/COLAB_YOURDATA_maDLC_TrainNetwork_VideoAnalysis.ipynb
@@ -16,10 +16,15 @@
"id": "RK255E7YoEIt"
},
"source": [
- "# DeepLabCut 2.2+ Toolbox - COLAB\n",
- "\n",
+ "# DeepLabCut 3.0+ Toolbox - COLAB for multi-animal projects!\n",
+ "\n",
+ "Some useful links:\n",
"\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "- [DeepLabCut's GitHub: github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Multi-Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html)\n",
+ "\n",
+ "\n",
+ "\n",
"\n",
"This notebook illustrates how to, for multi-animal projects, use the cloud-based GPU to:\n",
"- create a multi-animal training set\n",
@@ -57,49 +62,15 @@
"metadata": {},
"outputs": [],
"source": [
- "# Install TensorFlow, tensorpack and tf_slim versions compatible with DeepLabCut\n",
- "!pip install \"tensorflow==2.12.1\" \"tensorpack>=0.11\" \"tf_slim>=1.1.0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Downgrade PyTorch to a version using CUDA 11.8 and cudnn 8\n",
- "# This will also install the required CUDA libraries, for both PyTorch and TensorFlow\n",
- "!pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cu118"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Install the latest version of DeepLabCut\n",
- "!pip install \"git+https://github.com/DeepLabCut/DeepLabCut.git\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# As described in https://www.tensorflow.org/install/pip#step-by-step_instructions, \n",
- "# create symbolic links to NVIDIA shared libraries:\n",
- "!ln -svf /usr/local/lib/python3.11/dist-packages/nvidia/*/lib/*.so* /usr/local/lib/python3.11/dist-packages/tensorflow"
+ "# this will take a couple of minutes to install all the dependencies!\n",
+ "!pip install --pre deeplabcut"
]
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "bqUEb8TBdpWb"
- },
+ "metadata": {},
"source": [
- "After the package is installed, please click \"restart runtime\" if it appears for DLC changes to take effect in your COLAB environment. You will see this button at the output of the cells above ^."
+ "**(Be sure to click \"RESTART RUNTIME\" if it is displayed above before moving on !)** You will see this button at the output of the cells above ^."
]
},
{
@@ -112,15 +83,7 @@
"id": "oTwAcbq2-FZz",
"outputId": "9cfd8dcf-a0a8-4801-ed1d-fbcd5ec056af"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "DLC loaded in light mode; you cannot use any GUI (labeling, relabeling and standalone GUI)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import deeplabcut"
]
@@ -133,10 +96,8 @@
"source": [
"## Link your Google Drive (with your labeled data):\n",
"\n",
- "- This code assumes you locally installed DeepLabCut, created a project, extracted and labeled frames. Be sure to \"check Labels\" to confirm you are happy with your data. As, these frames are the only thing that is used to train your network. 💪 You can find all the docs to do this here: https://deeplabcut.github.io/DeepLabCut\n",
- "\n",
+ "- This code assumes you locally installed DeepLabCut, created a project, extracted and labeled frames. Be sure to \"check Labels\" to confirm you are happy with your data. As, these frames are the only thing that is used to train your network. 💪 You can find all the docs to do this here: [deeplabcut.github.io/DeepLabCut](https://deeplabcut.github.io/DeepLabCut/README.html)\n",
"- Next, place your DLC project folder into you Google Drive- i.e., copy the folder named \"Project-YourName-TheDate\" into Google Drive.\n",
- "\n",
"- Then, click run on the cell below to link this notebook to your Google Drive:"
]
},
@@ -148,11 +109,12 @@
},
"outputs": [],
"source": [
- "#Now, let's link to your Google Drive. Run this cell and follow the authorization instructions:\n",
- "#(We recommend putting a copy of the github repo in your google drive if you are using the demo \"examples\")\n",
+ "# Now, let's link to your GoogleDrive. Run this cell and follow the authorization instructions:\n",
+ "# (We recommend putting a copy of the github repo in your google drive if you are using the demo \"examples\")\n",
"\n",
"from google.colab import drive\n",
- "drive.mount('/content/drive')"
+ "\n",
+ "drive.mount(\"/content/drive\")"
]
},
{
@@ -161,7 +123,9 @@
"id": "Frnj1RVDyEqs"
},
"source": [
- "## Next, edit the few items below, and click run:\n"
+ "## Next, edit the few items below, and click run:\n",
+ "\n",
+ "YOU WILL NEED TO EDIT THE PROJECT PATH **in the `config.yaml` file** TO BE SET TO YOUR GOOGLE DRIVE LINK! Typically, this will be: `/content/drive/My Drive/yourProjectFolderName`\n"
]
},
{
@@ -173,18 +137,24 @@
"outputs": [],
"source": [
"# PLEASE EDIT THIS:\n",
- "ProjectFolderName = 'MontBlanc-Daniel-2019-12-16'\n",
- "VideoType = 'mp4' #, mp4, MOV, or avi, whatever you uploaded!\n",
+ "project_folder_name = \"MontBlanc-Daniel-2019-12-16\"\n",
+ "video_type = \"mp4\" #, mp4, MOV, or avi, whatever you uploaded!\n",
"\n",
+ "# No need to edit this, we are going to assume you put videos you want to analyze\n",
+ "# in the \"videos\" folder, but if this is NOT true, edit below:\n",
+ "videofile_path = [f\"/content/drive/My Drive/{project_folder_name}/videos/\"]\n",
+ "print(videofile_path)\n",
"\n",
- "# No need to edit this, we are going to assume you put videos you want to analyze in the \"videos\" folder, but if this is NOT true, edit below:\n",
- "videofile_path = ['/content/drive/My Drive/'+ProjectFolderName+'/videos/'] #Enter the list of videos or folder to analyze.\n",
- "videofile_path\n",
+ "# The prediction files and labeled videos will be saved in this `labeled-videos` folder\n",
+ "# in your project folder; if you want them elsewhere, you can edit this;\n",
+ "# if you want the output files in the same folder as the videos, set this to an empty string.\n",
+ "destfolder = f\"/content/drive/My Drive/{project_folder_name}/labeled-videos\"\n",
"\n",
- "#No need to edit this, as you set it when you passed the ProjectFolderName (above): \n",
- "path_config_file = '/content/drive/My Drive/'+ProjectFolderName+'/config.yaml'\n",
- "path_config_file\n",
- "#This creates a path variable that links to your google drive project"
+ "#No need to edit this, as you set it when you passed the ProjectFolderName (above):\n",
+ "path_config_file = f\"/content/drive/My Drive/{project_folder_name}/config.yaml\"\n",
+ "print(path_config_file)\n",
+ "\n",
+ "# This creates a path variable that links to your Google Drive project"
]
},
{
@@ -195,8 +165,7 @@
"source": [
"## Create a multi-animal training dataset:\n",
"\n",
- "- more info: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#create-training-dataset\n",
- "\n",
+ "- more info can be [found in the docs](https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#create-training-dataset)\n",
"- please check the text below, edit if needed, and then click run (this can take some time):"
]
},
@@ -208,7 +177,7 @@
},
"outputs": [],
"source": [
- "#OPTIONAL LEARNING: did you know you can check what each function does by running with a ?\n",
+ "# OPTIONAL LEARNING: did you know you can check what each function does by running with a ?\n",
"deeplabcut.create_multianimaltraining_dataset?"
]
},
@@ -222,11 +191,15 @@
"outputs": [],
"source": [
"# ATTENTION:\n",
- "#which shuffle do you want to create and train?\n",
- "shuffle = 1 #edit if needed; 1 is the default.\n",
+ "# Which shuffle do you want to create and train?\n",
+ "shuffle = 1 # Edit if needed; 1 is the default.\n",
"\n",
- "#if you labeled on Windows, please set the windows2linux=True:\n",
- "deeplabcut.create_multianimaltraining_dataset(path_config_file, Shuffles=[shuffle], net_type=\"dlcrnet_ms5\",windows2linux=False)"
+ "deeplabcut.create_multianimaltraining_dataset(\n",
+ " path_config_file,\n",
+ " Shuffles=[shuffle],\n",
+ " net_type=\"dlcrnet_ms5\",\n",
+ " engine=deeplabcut.Engine.PYTORCH,\n",
+ ")"
]
},
{
@@ -236,8 +209,7 @@
},
"source": [
"## Start training:\n",
- "This function trains the network for a specific shuffle of the training dataset. \n",
- " - more info: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#train-the-network"
+ "This function trains the network for a specific shuffle of the training dataset. More info can be found [in the docs](https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#train-the-network)."
]
},
{
@@ -248,14 +220,26 @@
},
"outputs": [],
"source": [
- "#let's also change the display and save_iters just in case Colab takes away the GPU... \n",
- "#Typically, you want to train to 50,000 - 200K iterations.\n",
- "#more info and there are more things you can set: https://github.com/DeepLabCut/DeepLabCut/blob/master/docs/functionDetails.md#g-train-the-network\n",
+ "# Let's also change the display and save_epochs just in case Colab takes away\n",
+ "# the GPU... If that happens, you can reload from a saved point using the\n",
+ "# `snapshot_path` argument to `deeplabcut.train_network`:\n",
+ "# deeplabcut.train_network(..., snapshot_path=\"/content/.../snapshot-050.pt\")\n",
+ "\n",
+ "# Typically, you want to train to ~200 epochs. We set the batch size to 8 to\n",
+ "# utilize the GPU's capabilities.\n",
"\n",
- "deeplabcut.train_network(path_config_file, shuffle=shuffle, displayiters=100,saveiters=1000, maxiters=75000, allow_growth=True)\n",
+ "# More info and there are more things you can set:\n",
+ "# https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html#g-train-the-network\n",
"\n",
- "#this will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end (default, 50K iterations). \n",
- "#Whichever you chose, you will see what looks like an error message, but it's not an error - don't worry...."
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=shuffle,\n",
+ " save_epochs=5,\n",
+ " epochs=200,\n",
+ " batch_size=8,\n",
+ ")\n",
+ "\n",
+ "# This will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end."
]
},
{
@@ -264,7 +248,7 @@
"id": "RiDwIVf5-3H_"
},
"source": [
- "**When you hit \"STOP\" you will get a KeyInterrupt \"error\"! No worries! :)**"
+ "Note, that **when you hit \"STOP\" you will get a `KeyboardInterrupt` \"error\"! No worries! :)**"
]
},
{
@@ -275,13 +259,10 @@
"source": [
"## Start evaluating: \n",
"\n",
- " - First, we evaluate the pose estimation performance.\n",
- "\n",
- "- This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images) and stores the results as .5 and .csv file in a subdirectory under **evaluation-results**\n",
- "\n",
+ "- First, we evaluate the pose estimation performance.\n",
+ "- This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images) and stores the results as .5 and .csv file in a subdirectory under **evaluation-results-pytorch**\n",
"- If the scoremaps do not look accurate, don't proceed to tracklet assembly; please consider (1) adding more data, (2) adding more bodyparts!\n",
- "\n",
- "- more info: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#evaluate-the-trained-network\n",
+ "- More info can be [found in the docs](https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#evaluate-the-trained-network)\n",
"\n",
"Here is an example of what you'd aim to see before proceeding:\n",
"\n",
@@ -297,10 +278,11 @@
},
"outputs": [],
"source": [
- "#let's evaluate first:\n",
- "deeplabcut.evaluate_network(path_config_file,Shuffles=[shuffle], plotting=True)\n",
- "#plot a few scoremaps:\n",
- "deeplabcut.extract_save_all_maps(path_config_file, shuffle=shuffle, Indices=[0])"
+ "# Let's evaluate first:\n",
+ "deeplabcut.evaluate_network(path_config_file, Shuffles=[shuffle], plotting=True)\n",
+ "\n",
+ "# plot a few scoremaps:\n",
+ "deeplabcut.extract_save_all_maps(path_config_file, shuffle=shuffle, Indices=[0, 1, 2, 3])"
]
},
{
@@ -336,7 +318,14 @@
"#EDIT OPTION: which video(s) do you want to analyze? You can pass a path or a folder:\n",
"# currently, if you run \"as is\" it assumes you have a video in the DLC project video folder!\n",
"\n",
- "deeplabcut.analyze_videos(path_config_file,videofile_path, shuffle=shuffle, videotype=VideoType)"
+ "deeplabcut.analyze_videos(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " shuffle=shuffle,\n",
+ " videotype=video_type,\n",
+ " auto_track=False,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -345,7 +334,7 @@
"id": "91xBLOcBzGxo"
},
"source": [
- "Optional: Now you have the option to check the raw detections before animals are assembled. To do so, pass a video path:"
+ "Optional: Now you have the option to check the raw detections before animals are tracked. To do so, pass a video path:"
]
},
{
@@ -360,11 +349,13 @@
"## look at the output video; if the pose estimation (i.e. key points)\n",
"## don't look good, don't proceed with tracking - add more data to your training set and re-train!\n",
"\n",
- "#EDIT: let's check a specific video (PLEASE EDIT VIDEO PATH):\n",
- "Specific_videofile = '/content/drive/MyDrive/DeepLabCut_maDLC_DemoData/MontBlanc-Daniel-2019-12-16/videos/short.mov'\n",
+ "# EDIT: let's check a specific video (PLEASE EDIT VIDEO PATH):\n",
+ "specific_videofile = \"/content/drive/MyDrive/DeepLabCut_maDLC_DemoData/MontBlanc-Daniel-2019-12-16/videos/short.mov\"\n",
"\n",
- "#don't edit:\n",
- "deeplabcut.create_video_with_all_detections(path_config_file, [Specific_videofile], shuffle=shuffle)"
+ "# Don't edit:\n",
+ "deeplabcut.create_video_with_all_detections(\n",
+ " path_config_file, [specific_videofile], shuffle=shuffle, destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -373,7 +364,7 @@
"id": "3-OgTJ0Lz20e"
},
"source": [
- "If the resulting video (ends in full.mp4) is not good, we highly recommend adding more data and training again. See here: https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#decision-break-point"
+ "If the resulting video (ends in full.mp4) is not good, we highly recommend adding more data and training again. See [here, in the docs](https://deeplabcut.github.io/DeepLabCut/docs/maDLC_UserGuide.html#decision-break-point)."
]
},
{
@@ -382,14 +373,15 @@
"id": "PxRLS2_-r55K"
},
"source": [
- "# Next, we will assemble animals using our data-driven optimal graph method:\n",
+ "## Next, we will assemble animals using our data-driven optimal graph method:\n",
"\n",
- "- Here, we will find the optimal graph, which matches the \"data-driven\" method from our paper (Figure adapted from Lauer et al. 2021):\n",
+ "During video analysis, animals are assembled using the optimal graph, which matches the \"data-driven\" method from our paper (Figure adapted from Lauer et al. 2021)\n",
"\n",
"\n",
"\n",
+ "The optimal graph is computed when `evaluate_network` - so make sure you don't skip that step!\n",
"\n",
- "- note, you can set the number of animals you expect to see, so check, edit, then click run:"
+ "**Note**: you can set the number of animals you expect to see, so check, edit, then click run:"
]
},
{
@@ -401,23 +393,37 @@
"outputs": [],
"source": [
"#Check and edit:\n",
- "numAnimals = 4 #how many animals do you expect to find?\n",
- "tracktype= 'box' #box, skeleton, ellipse:\n",
- "#-- ellipse is recommended, unless you have a single-point ma project, then use BOX!\n",
- "\n",
- "#Optional: \n",
- "#imagine you tracked a point that is not useful for assembly, \n",
- "#like a tail tip that is far from the body, consider dropping it for this step (it's still used later)!\n",
- "#To drop it, uncomment the next line TWO lines and add your parts(s):\n",
- "\n",
- "#bodypart= 'Tail_end'\n",
- "#deeplabcut.convert_detections2tracklets(path_config_file, videofile_path, videotype=VideoType, shuffle=shuffle, overwrite=True, ignore_bodyparts=[bodypart])\n",
- "\n",
- "#OR don't drop, just click RUN:\n",
- "deeplabcut.convert_detections2tracklets(path_config_file, videofile_path, videotype=VideoType, \n",
- " shuffle=shuffle, overwrite=True)\n",
- "\n",
- "deeplabcut.stitch_tracklets(path_config_file, videofile_path, shuffle=shuffle, track_method=tracktype, n_tracks=numAnimals)"
+ "num_animals = 4 # How many animals do you expect to find?\n",
+ "track_type= \"box\" # box, skeleton, ellipse\n",
+ "#-- ellipse is recommended, unless you have a single-point MA project, then use BOX!\n",
+ "\n",
+ "# Optional:\n",
+ "# imagine you tracked a point that is not useful for assembly,\n",
+ "# like a tail tip that is far from the body, consider dropping it for this step (it's still used later)!\n",
+ "# To drop it, uncomment the next line TWO lines and add your parts(s):\n",
+ "\n",
+ "# bodypart= 'Tail_end'\n",
+ "# deeplabcut.convert_detections2tracklets(path_config_file, videofile_path, videotype=VideoType, shuffle=shuffle, overwrite=True, ignore_bodyparts=[bodypart])\n",
+ "\n",
+ "# OR don't drop, just click RUN:\n",
+ "deeplabcut.convert_detections2tracklets(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " videotype=video_type,\n",
+ " shuffle=shuffle,\n",
+ " track_method=track_type,\n",
+ " destfolder=destfolder,\n",
+ " overwrite=True,\n",
+ ")\n",
+ "\n",
+ "deeplabcut.stitch_tracklets(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " shuffle=shuffle,\n",
+ " track_method=track_type,\n",
+ " n_tracks=num_animals,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -437,11 +443,14 @@
},
"outputs": [],
"source": [
- "deeplabcut.filterpredictions(path_config_file, \n",
- " videofile_path, \n",
- " shuffle=shuffle,\n",
- " videotype=VideoType, \n",
- " track_method = tracktype)"
+ "deeplabcut.filterpredictions(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " shuffle=shuffle,\n",
+ " videotype=video_type,\n",
+ " track_method=track_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -461,7 +470,14 @@
},
"outputs": [],
"source": [
- "deeplabcut.plot_trajectories(path_config_file, videofile_path, videotype=VideoType, shuffle=shuffle, track_method=tracktype)"
+ "deeplabcut.plot_trajectories(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " videotype=video_type,\n",
+ " shuffle=shuffle,\n",
+ " track_method=track_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
},
{
@@ -491,13 +507,17 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,\n",
- " videofile_path, \n",
- " shuffle=shuffle, \n",
- " color_by=\"individual\",\n",
- " videotype=VideoType, \n",
- " save_frames=False,\n",
- " filtered=True)"
+ "deeplabcut.create_labeled_video(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " shuffle=shuffle,\n",
+ " color_by=\"individual\",\n",
+ " videotype=video_type,\n",
+ " save_frames=False,\n",
+ " filtered=True,\n",
+ " track_method=track_type,\n",
+ " destfolder=destfolder,\n",
+ ")"
]
}
],
diff --git a/examples/COLAB/COLAB_transformer_reID.ipynb b/examples/COLAB/COLAB_transformer_reID.ipynb
index 814852a296..e84ce61c22 100644
--- a/examples/COLAB/COLAB_transformer_reID.ipynb
+++ b/examples/COLAB/COLAB_transformer_reID.ipynb
@@ -16,7 +16,7 @@
"id": "TGChzLdc-lUJ"
},
"source": [
- "# DeepLabCut 2.2 Toolbox Demo on how to use our Pose Transformer for unsupervised identity tracking of animals\n",
+ "# DeepLabCut 3.0 Toolbox Demo: How to use our Pose Transformer for unsupervised identity tracking of animals\n",
"\n",
"\n",
"https://github.com/DeepLabCut/DeepLabCut\n",
@@ -43,40 +43,8 @@
"metadata": {},
"outputs": [],
"source": [
- "# Install TensorFlow, tensorpack and tf_slim versions compatible with DeepLabCut\n",
- "!pip install \"tensorflow==2.12.1\" \"tensorpack>=0.11\" \"tf_slim>=1.1.0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Downgrade PyTorch to a version using CUDA 11.8 and cudnn 8\n",
- "# This will also install the required CUDA libraries, for both PyTorch and TensorFlow\n",
- "!pip install torch==2.3.1 torchvision --index-url https://download.pytorch.org/whl/cu118"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Install the latest version of DeepLabCut\n",
- "!pip install \"git+https://github.com/DeepLabCut/DeepLabCut.git\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# As described in https://www.tensorflow.org/install/pip#step-by-step_instructions, \n",
- "# create symbolic links to NVIDIA shared libraries:\n",
- "!ln -svf /usr/local/lib/python3.11/dist-packages/nvidia/*/lib/*.so* /usr/local/lib/python3.11/dist-packages/tensorflow"
+ "# Install the latest DeepLabCut version:\n",
+ "!pip install --pre deeplabcut"
]
},
{
@@ -103,17 +71,17 @@
"from io import BytesIO\n",
"from zipfile import ZipFile\n",
"\n",
- "url_record = 'https://zenodo.org/api/records/7883589'\n",
+ "url_record = \"https://zenodo.org/api/records/7883589\"\n",
"response = requests.get(url_record)\n",
"if response.status_code == 200:\n",
- " file = response.json()['files'][0]\n",
- " title = file['key']\n",
+ " file = response.json()[\"files\"][0]\n",
+ " title = file[\"key\"]\n",
" print(f\"Downloading {title}...\")\n",
- " with requests.get(file['links']['self'], stream=True) as r:\n",
+ " with requests.get(file[\"links\"][\"self\"], stream=True) as r:\n",
" with ZipFile(BytesIO(r.content)) as zf:\n",
- " zf.extractall(path='/content')\n",
+ " zf.extractall(path=\"/content\")\n",
"else:\n",
- " raise ValueError(f'The URL {url_record} could not be reached.')"
+ " raise ValueError(f\"The URL {url_record} could not be reached.\")"
]
},
{
@@ -124,7 +92,7 @@
"source": [
"## Analyze a novel 3 mouse video with our maDLC DLCRNet, pretrained on 3 mice data \n",
"\n",
- "###in one step, since auto_track=True you extract detections and association costs, create tracklets, & stitch them. We can use this to compare to the transformer-guided tracking below.\n"
+ "In one step, since `auto_track=True` you extract detections and association costs, create tracklets, & stitch them. We can use this to compare to the transformer-guided tracking below.\n"
]
},
{
@@ -135,14 +103,14 @@
},
"outputs": [],
"source": [
- "import deeplabcut as dlc\n",
+ "import deeplabcut\n",
"import os\n",
"\n",
"project_path = \"/content/demo-me-2021-07-14\"\n",
"config_path = os.path.join(project_path, \"config.yaml\")\n",
"video = os.path.join(project_path, \"videos\", \"videocompressed1.mp4\")\n",
"\n",
- "dlc.analyze_videos(config_path,[video], shuffle=0, videotype=\"mp4\",auto_track=True)"
+ "deeplabcut.analyze_videos(config_path,[video], shuffle=0, videotype=\"mp4\", auto_track=True)"
]
},
{
@@ -172,16 +140,11 @@
"outputs": [],
"source": [
"#Filter the predictions to remove small jitter, if desired:\n",
- "dlc.filterpredictions(config_path,\n",
- " [video],\n",
- " shuffle=0,\n",
- " videotype='mp4',\n",
- " )\n",
- "\n",
- "dlc.create_labeled_video(\n",
+ "deeplabcut.filterpredictions(config_path, [video], shuffle=0, videotype=\"mp4\")\n",
+ "deeplabcut.create_labeled_video(\n",
" config_path,\n",
" [video],\n",
- " videotype='mp4',\n",
+ " videotype=\"mp4\",\n",
" shuffle=0,\n",
" color_by=\"individual\",\n",
" keypoints_only=False,\n",
@@ -218,9 +181,7 @@
"id": "7w9BDIA7BB_i"
},
"outputs": [],
- "source": [
- "dlc.plot_trajectories(config_path, [video], shuffle=0,videotype='mp4')"
- ]
+ "source": "deeplabcut.plot_trajectories(config_path, [video], shuffle=0,videotype=\"mp4\")"
},
{
"cell_type": "markdown",
@@ -241,10 +202,14 @@
},
"outputs": [],
"source": [
- "dlc.transformer_reID(config_path, [video],\n",
- " shuffle=0, videotype='mp4',\n",
- " track_method='ellipse',n_triplets=100\n",
- " )"
+ "deeplabcut.transformer_reID(\n",
+ " config_path,\n",
+ " [video],\n",
+ " shuffle=0,\n",
+ " videotype=\"mp4\",\n",
+ " track_method=\"ellipse\",\n",
+ " n_triplets=100,\n",
+ ")"
]
},
{
@@ -264,10 +229,13 @@
},
"outputs": [],
"source": [
- "dlc.plot_trajectories(config_path, [video],\n",
- " shuffle=0,videotype='mp4',\n",
- " track_method=\"transformer\"\n",
- " )"
+ "deeplabcut.plot_trajectories(\n",
+ " config_path,\n",
+ " [video],\n",
+ " shuffle=0,\n",
+ " videotype=\"mp4\",\n",
+ " track_method=\"transformer\",\n",
+ ")"
]
},
{
@@ -278,10 +246,10 @@
},
"outputs": [],
"source": [
- "dlc.create_labeled_video(\n",
+ "deeplabcut.create_labeled_video(\n",
" config_path,\n",
" [video],\n",
- " videotype='mp4',\n",
+ " videotype=\"mp4\",\n",
" shuffle=0,\n",
" color_by=\"individual\",\n",
" keypoints_only=False,\n",
diff --git a/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb b/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb
index 92d7eb9ff9..42242b6664 100644
--- a/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb
+++ b/examples/JUPYTER/Demo_3D_DeepLabCut.ipynb
@@ -275,9 +275,9 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python [conda env:DEEPLABCUT_newGUI] *",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "conda-env-DEEPLABCUT_newGUI-py"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -289,7 +289,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.13"
+ "version": "3.11.11"
}
},
"nbformat": 4,
diff --git a/examples/JUPYTER/Demo_labeledexample_MouseReaching.ipynb b/examples/JUPYTER/Demo_labeledexample_MouseReaching.ipynb
index cc45d9a2a7..312e0a1505 100644
--- a/examples/JUPYTER/Demo_labeledexample_MouseReaching.ipynb
+++ b/examples/JUPYTER/Demo_labeledexample_MouseReaching.ipynb
@@ -8,7 +8,11 @@
},
"source": [
"# DeepLabCut Toolbox - DEMO (mouse reaching)\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "\n",
+ "Some resources that can be useful:\n",
+ "\n",
+ "- [github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Single Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html)\n",
"\n",
"#### The notebook accompanies the following user-guide:\n",
"\n",
@@ -35,7 +39,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Import the toolbox:"
+ "## Import the Toolbox and Required Libraries"
]
},
{
@@ -48,6 +52,8 @@
},
"outputs": [],
"source": [
+ "from pathlib import Path\n",
+ "\n",
"import deeplabcut"
]
},
@@ -64,12 +70,17 @@
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
- "# Note that parameters of this project can be seen at: *Reaching-Mackenzie-2018-08-30/config.yaml*\n",
- "from pathlib import Path\n",
+ "# Create a variable to set the config.yaml file path:\n",
+ "# If this path does not point to the project from the URL below,\n",
+ "# edit it to make sure it does:\n",
+ "# https://github.com/DeepLabCut/DeepLabCut/tree/main/examples/Reaching-Mackenzie-2018-08-30\n",
+ "# \n",
+ "# Example - Linux/OSX\n",
+ "# path_config_file = \"/Users/john/DeepLabCut/examples/Reaching-Mackenzie-2018-08-30/config.yaml\"\n",
+ "# Example - Windows\n",
+ "# path_config_file = r\"C:\\DeepLabCut\\examples\\Reaching-Mackenzie-2018-08-30\\config.yaml\"\n",
"\n",
- "#create a variable to set the config.yaml file path:\n",
- "path_config_file = os.path.join(os.getcwd(),'Reaching-Mackenzie-2018-08-30/config.yaml')\n",
+ "path_config_file = str(Path.cwd() / \"Reaching-Mackenzie-2018-08-30\" / \"config.yaml\")\n",
"print(path_config_file)"
]
},
@@ -99,8 +110,8 @@
},
"outputs": [],
"source": [
- "#let's load some demo data, and create a training set \n",
- "#(note, this function is not used when you create your own project):\n",
+ "# Let's load some demo data, and create a training set \n",
+ "# (note, this function is not used when you create your own project):\n",
"\n",
"deeplabcut.load_demo_data(path_config_file)"
]
@@ -115,7 +126,7 @@
},
"outputs": [],
"source": [
- "#Perhaps plot the labels to see how the frames were annotated:\n",
+ "# Perhaps plot the labels to see how the frames were annotated:\n",
"\n",
"deeplabcut.check_labels(path_config_file)"
]
@@ -128,11 +139,12 @@
},
"source": [
"## Start training of Feature Detectors\n",
- "This function trains the network for a specific shuffle of the training dataset. **The user can set various parameters in /Reaching-Mackenzie-2018-08-30/dlc-models/ReachingAug30-trainset95shuffle1/iteration-0/train/pose_cfg.yaml.**\n",
"\n",
- "Training can be stopped at any time. Note that the weights are only stored every 'save_iters' steps. For this demo the it is advisable to store & display the progress very often (i.e. display every 20, save every 100). In practice this is inefficient (in reality, you will train until ~200K, so we save every 50K).\n",
+ "This function trains the network for a specific shuffle of the training dataset. **The user can set various parameters in `.../Reaching-Mackenzie-2018-08-30/dlc-models-pytorch/iteration-0/ReachingAug30-trainset95shuffle1/train/pytorch_config.yaml`**. For more information about the variables that can be set, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config.html)!\n",
"\n",
- "**We recommend just training for 10-20 min, as you aren't running this demo to use DLC, just to work through the steps. In total, this demo should take you LESS THAN 1 HOUR!**"
+ "Training can be stopped at any time. Note that the weights are only stored every 'save_epochs' steps. For this demo the it is advisable to store & display the progress very often (i.e. display every 20, save every 2). In practice this is inefficient (in reality, you will train until ~200, so we save every 10).\n",
+ "\n",
+ "**We recommend just training for 15-20 min, as you aren't running this demo to use DLC, just to work through the steps. In total, this demo should take you LESS THAN 1 HOUR!**"
]
},
{
@@ -142,24 +154,26 @@
"colab": {},
"colab_type": "code",
"id": "jg96O2acywnW",
- "scrolled": false
+ "scrolled": true
},
"outputs": [],
"source": [
- "deeplabcut.train_network(path_config_file, shuffle=1, saveiters=300, displayiters=10)\n",
- "#notice the variables \"saveiters\" and \"dsiplayiters\" that can be set in the function\n",
+ "# notice the variables \"save_epochs\" and \"displayiters\" that can be set in the function\n",
+ "deeplabcut.train_network(path_config_file, shuffle=1, save_epochs=2, displayiters=10)\n",
"\n",
- "#you just need to run this until you get at least 1 snapshot, which is set by: \"save_iters\" \n",
- "#(so in this case you could stop after 500!) How do I stop? Click the STOP button!\n",
- "# To train until ~2,000 iterations on a CPU should be ~30 min"
+ "# you just need to run this until you get at least 1 snapshot, which is set by: \"save_epochs\" \n",
+ "# (so in this case you could stop after 2 epochs!) How do I stop? Click the STOP button!\n",
+ "\n",
+ "# To train until ~50 epochs on a CPU should be ~15 min\n",
+ "# Every 10 epochs, your model will be evaluated. You can keep an eye on model performance\n",
+ "# while the model is being trained."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "*Note, that if it reaches the end (default 1M) or you stop it (by \"stop\" or by CTRL+C), \n",
- "you will see an keyboard interrupt \"error\", but it is not a real error, i.e. you can ignore this.*"
+ "*Note, that if you stop it (by \"stop\" or by CTRL+C), you will see an keyboard interrupt \"error\", but it is not a real error, i.e. you can ignore this.*"
]
},
{
@@ -171,7 +185,7 @@
"source": [
"## Evaluate the trained network\n",
"\n",
- "This function evaluates a trained model for a specific shuffle/shuffles at a particular training state (snapshot) or on all the states. The network is evaluated on the data set (images) and stores the results as .csv file in a subdirectory under **evaluation-results**.\n",
+ "This function evaluates a trained model for a specific shuffle/shuffles at a particular training state (snapshot) or on all the states. The network is evaluated on the data set (images) and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**.\n",
"\n",
"You can change various parameters in the ```config.yaml``` file of this project. For the evaluation one can change pcutoff. This cutoff also influences how likely estimated positions need to be so that they are shown in the plots."
]
@@ -187,14 +201,14 @@
},
"outputs": [],
"source": [
- "deeplabcut.evaluate_network(path_config_file,plotting=True)"
+ "deeplabcut.evaluate_network(path_config_file, plotting=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "**NOTE: depending on your set up sometimes you get some \"matplotlib errors, but these are not important**\n",
+ "**NOTE: depending on your setup sometimes you get some \"matplotlib errors, but these are not important**\n",
"\n",
"Now you can go check out the images. Given the limited data input and it took ~20 mins to test this out, it is not meant to track well, so don't be alarmed. This is just to get you familiar with the workflow... "
]
@@ -223,12 +237,12 @@
"outputs": [],
"source": [
"# Set the video path:\n",
- "#The video can be the one you trained with and new videos that look similar, i.e. same experiments, etc.\n",
+ "# The video can be the one you trained with and new videos that look similar, i.e. same experiments, etc.\n",
"# You can add individual videos, OR just a folder - it will skip videos that are already analyzed once.\n",
"\n",
- "#i.e you can run 'reachingvideo1' and/or 'MovieS2_Perturbation_noLaser_compressed'\n",
+ "# i.e. you can run 'reachingvideo1' and/or 'MovieS2_Perturbation_noLaser_compressed'\n",
"\n",
- "videofile_path = os.path.join(os.getcwd(),'Reaching-Mackenzie-2018-08-30/videos/reachingvideo1.avi') "
+ "videofile_path = str(Path(path_config_file).parent / \"videos\" / \"reachingvideo1.avi\")"
]
},
{
@@ -243,8 +257,9 @@
"outputs": [],
"source": [
"print(\"Start Analyzing the video!\")\n",
- "deeplabcut.analyze_videos(path_config_file,[videofile_path])\n",
- "# this video takes ~ 8 min to analyze with a CPU"
+ "\n",
+ "deeplabcut.analyze_videos(path_config_file, [videofile_path])\n",
+ "# this video takes ~ 1 min to analyze with a CPU"
]
},
{
@@ -279,7 +294,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,[videofile_path], draw_skeleton=True)"
+ "deeplabcut.create_labeled_video(path_config_file, [videofile_path], draw_skeleton=True)"
]
},
{
@@ -305,9 +320,9 @@
"outputs": [],
"source": [
"%matplotlib notebook\n",
- "deeplabcut.plot_trajectories(path_config_file,[videofile_path],showfigures=True)\n",
+ "deeplabcut.plot_trajectories(path_config_file, [videofile_path], showfigures=True)\n",
"\n",
- "#These plots can are interactive and can be customized (see https://matplotlib.org/)"
+ "# These plots are interactive and can be customized (see https://matplotlib.org/)"
]
},
{
@@ -339,11 +354,16 @@
"colab": {},
"colab_type": "code",
"id": "RJGiDKuUywoC",
- "scrolled": false
+ "scrolled": true
},
"outputs": [],
"source": [
- "deeplabcut.extract_outlier_frames(path_config_file,videofile_path,outlieralgorithm='uncertain',p_bound=.2)"
+ "deeplabcut.extract_outlier_frames(\n",
+ " path_config_file,\n",
+ " videofile_path,\n",
+ " outlieralgorithm=\"uncertain\",\n",
+ " p_bound=0.2,\n",
+ ")"
]
},
{
@@ -365,7 +385,9 @@
"source": [
"## Manually correct labels\n",
"\n",
- "This step allows the user to correct the labels in the extracted frames. Navigate to the folder with the videos and use the GUI as described in the protocol to update the labels."
+ "This step allows the user to correct the labels in the extracted frames. Navigate to the folder with the videos and use the GUI as described in the protocol to update the labels.\n",
+ "\n",
+ "For documentation regarding the GUI, [look at the docs for `napari-deeplabcut`](https://github.com/DeepLabCut/napari-deeplabcut/tree/main) - and specifically _\"3. Refining labels – the image folder contains a machinelabels-iter<#>.h5 file.\"_!"
]
},
{
@@ -379,9 +401,6 @@
},
"outputs": [],
"source": [
- "#GUI pops up! \n",
- "#sometimes you need to restart the kernel for the GUI to launch.\n",
- "%gui wx\n",
"deeplabcut.refine_labels(path_config_file)"
]
},
@@ -436,7 +455,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_training_dataset(path_config_file)"
+ "deeplabcut.create_training_dataset(path_config_file, engine=deeplabcut.Engine.PYTORCH)"
]
},
{
@@ -446,7 +465,7 @@
"id": "8fhL6nG2ywoW"
},
"source": [
- "Now one can train the network again... (with the expanded data set)"
+ "Now one can train the network again... (with the expanded data set). We can continue training from the snapshot we already have by using the `snapshot_path` argument - instead of training the model from scratch, it will load the weights we already have and fine-tune them!"
]
},
{
@@ -459,8 +478,31 @@
},
"outputs": [],
"source": [
- "deeplabcut.train_network(path_config_file)"
+ "snapshot_path = ( # Edit me if needed! Select the path to the snapshot to continue training from!\n",
+ " Path(path_config_file).parent / \n",
+ " \"dlc-models-pytorch\" / \n",
+ " \"iteration-0\" / \n",
+ " \"ReachingAug30-trainset95shuffle1\" / \n",
+ " \"train\" / \n",
+ " \"snapshot-best-080.pt\"\n",
+ ")\n",
+ "\n",
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=2,\n",
+ " displayiters=10,\n",
+ " batch_size=8,\n",
+ " snapshot_path=snapshot_path,\n",
+ ")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -471,9 +513,9 @@
"version": "0.3.2"
},
"kernelspec": {
- "display_name": "Python [conda env:DLC2]",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "conda-env-DLC2-py"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -485,7 +527,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.9"
+ "version": "3.11.11"
},
"varInspector": {
"cols": {
diff --git a/examples/JUPYTER/Demo_labeledexample_Openfield.ipynb b/examples/JUPYTER/Demo_labeledexample_Openfield.ipynb
index 3b75f5b60a..d4458f1111 100644
--- a/examples/JUPYTER/Demo_labeledexample_Openfield.ipynb
+++ b/examples/JUPYTER/Demo_labeledexample_Openfield.ipynb
@@ -8,7 +8,11 @@
},
"source": [
"# DeepLabCut Toolbox - Open-Field DEMO\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "\n",
+ "Some resources that can be useful:\n",
+ "\n",
+ "- [github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Single Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html)\n",
"\n",
"#### The notebook accompanies the following user-guide:\n",
"\n",
@@ -34,7 +38,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -43,65 +47,50 @@
"outputs": [],
"source": [
"# Importing the toolbox (takes several seconds)\n",
+ "from pathlib import Path\n",
+ "\n",
"import deeplabcut"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "WOEHc0MeywnJ"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Loaded, now creating training data...\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30 already exists!\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/labeled-data/short_mp3/CollectedData_Pranav.h5 not found (perhaps not annotated)\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1 already exists!\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1//train already exists!\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1//test already exists!\n",
- "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "# Loading example data set:\n",
- "import os\n",
+ "# Create a variable to set the config.yaml file path:\n",
+ "# If this path does not point to the project from the URL below,\n",
+ "# edit it to make sure it does:\n",
+ "# https://github.com/DeepLabCut/DeepLabCut/tree/main/examples/openfield-Pranav-2018-10-30\n",
+ "# \n",
+ "# Example - Linux/OSX\n",
+ "# path_config_file = \"/Users/john/DeepLabCut/examples/openfield-Pranav-2018-10-30/config.yaml\"\n",
+ "# Example - Windows\n",
+ "# path_config_file = r\"C:\\DeepLabCut\\examples\\openfield-Pranav-2018-10-30\\config.yaml\"\n",
+ "#\n",
"# Note that parameters of this project can be seen at: *openfield-Pranav-2018-10-30/config.yaml*\n",
- "from pathlib import Path\n",
- "path_config_file = os.path.join(os.getcwd(),'openfield-Pranav-2018-10-30/config.yaml')\n",
+ "\n",
+ "path_config_file = str(Path.cwd() / \"openfield-Pranav-2018-10-30\" / \"config.yaml\")\n",
"deeplabcut.load_demo_data(path_config_file)"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "ROlflqQLywnP"
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Creating images with labels by Pranav.\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1_labeled already exists!\n",
- "They are stored in the following folder: /home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1_labeled.\n",
- "Attention: /home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/labeled-data/short_mp3 does not appear to have labeled data!\n",
- "If all the labels are ok, then use the function 'create_training_dataset' to create the training dataset!\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "#[OPTIONAL] Perhaps plot the labels to see how the frames were annotated:\n",
- "#(note, this project was created in Linux, so you might have an error in Windows, but this is an optional step)\n",
+ "# [OPTIONAL] Perhaps plot the labels to see how the frames were annotated:\n",
+ "# (note, this project was created in Linux, so you might have an error in Windows, but this is an optional step)\n",
+ "\n",
"deeplabcut.check_labels(path_config_file)"
]
},
@@ -113,141 +102,30 @@
},
"source": [
"## Start training of Feature Detectors\n",
- "This function trains the network for a specific shuffle of the training dataset. The user can set various parameters in */openfield-Pranav-2018-10-30/dlc-models/.../pose_cfg.yaml*. \n",
"\n",
- "Training can be stopped at any time. Note that the weights are only stored every 'save_iters' steps. For this demo the state it is advisable to store & display the progress very often. In practice this is inefficient. "
+ "This function trains the network for a specific shuffle of the training dataset. The user can set various parameters in `/openfield-Pranav-2018-10-30/dlc-models-pytorch/.../pytorch_config.yaml`. For more information about the variables that can be set, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config.html)!\n",
+ "\n",
+ "Training can be stopped at any time. Note that the weights are only stored every 'save_epochs' epochs. For this demo the state it is advisable to store & display the progress very often. In practice this is inefficient. You should see the model start converging around 50 to 60 epochs; you can continue training it longer to improve performance."
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jg96O2acywnW",
- "scrolled": false
+ "scrolled": true
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Config:\n",
- "{'all_joints': [[0], [1], [2], [3]],\n",
- " 'all_joints_names': ['snout', 'leftear', 'rightear', 'tailbase'],\n",
- " 'batch_size': 1,\n",
- " 'bottomheight': 400,\n",
- " 'crop': True,\n",
- " 'crop_pad': 0,\n",
- " 'cropratio': 0.4,\n",
- " 'dataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/openfield_Pranav95shuffle1.mat',\n",
- " 'dataset_type': 'default',\n",
- " 'display_iters': 1000,\n",
- " 'fg_fraction': 0.25,\n",
- " 'global_scale': 0.8,\n",
- " 'init_weights': '/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt',\n",
- " 'intermediate_supervision': False,\n",
- " 'intermediate_supervision_layer': 12,\n",
- " 'leftwidth': 400,\n",
- " 'location_refinement': True,\n",
- " 'locref_huber_loss': True,\n",
- " 'locref_loss_weight': 0.05,\n",
- " 'locref_stdev': 7.2801,\n",
- " 'log_dir': 'log',\n",
- " 'max_input_size': 1500,\n",
- " 'mean_pixel': [123.68, 116.779, 103.939],\n",
- " 'metadataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/Documentation_data-openfield_95shuffle1.pickle',\n",
- " 'min_input_size': 64,\n",
- " 'minsize': 100,\n",
- " 'mirror': False,\n",
- " 'multi_step': [[0.005, 10000],\n",
- " [0.02, 430000],\n",
- " [0.002, 730000],\n",
- " [0.001, 1030000]],\n",
- " 'net_type': 'resnet_50',\n",
- " 'num_joints': 4,\n",
- " 'optimizer': 'sgd',\n",
- " 'pos_dist_thresh': 17,\n",
- " 'project_path': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30',\n",
- " 'regularize': False,\n",
- " 'rightwidth': 400,\n",
- " 'save_iters': 50000,\n",
- " 'scale_jitter_lo': 0.5,\n",
- " 'scale_jitter_up': 1.25,\n",
- " 'scoremap_dir': 'test',\n",
- " 'shuffle': True,\n",
- " 'snapshot_prefix': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1/train/snapshot',\n",
- " 'stride': 8.0,\n",
- " 'topheight': 400,\n",
- " 'use_gt_segm': False,\n",
- " 'video': False,\n",
- " 'video_batch': False,\n",
- " 'weigh_negatives': False,\n",
- " 'weigh_only_present_joints': False,\n",
- " 'weigh_part_predictions': False,\n",
- " 'weight_decay': 0.0001}\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "INFO:tensorflow:Restoring parameters from /home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt\n",
- "Display_iters overwritten as 10\n",
- "Save_iters overwritten as 100\n",
- "Training parameter:\n",
- "{'stride': 8.0, 'weigh_part_predictions': False, 'weigh_negatives': False, 'fg_fraction': 0.25, 'weigh_only_present_joints': False, 'mean_pixel': [123.68, 116.779, 103.939], 'shuffle': True, 'snapshot_prefix': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1/train/snapshot', 'log_dir': 'log', 'global_scale': 0.8, 'location_refinement': True, 'locref_stdev': 7.2801, 'locref_loss_weight': 0.05, 'locref_huber_loss': True, 'optimizer': 'sgd', 'intermediate_supervision': False, 'intermediate_supervision_layer': 12, 'regularize': False, 'weight_decay': 0.0001, 'mirror': False, 'crop_pad': 0, 'scoremap_dir': 'test', 'dataset_type': 'default', 'use_gt_segm': False, 'batch_size': 1, 'video': False, 'video_batch': False, 'crop': True, 'cropratio': 0.4, 'minsize': 100, 'leftwidth': 400, 'rightwidth': 400, 'topheight': 400, 'bottomheight': 400, 'all_joints': [[0], [1], [2], [3]], 'all_joints_names': ['snout', 'leftear', 'rightear', 'tailbase'], 'dataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/openfield_Pranav95shuffle1.mat', 'display_iters': 1000, 'init_weights': '/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt', 'max_input_size': 1500, 'metadataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/Documentation_data-openfield_95shuffle1.pickle', 'min_input_size': 64, 'multi_step': [[0.005, 10000], [0.02, 430000], [0.002, 730000], [0.001, 1030000]], 'net_type': 'resnet_50', 'num_joints': 4, 'pos_dist_thresh': 17, 'project_path': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30', 'save_iters': 50000, 'scale_jitter_lo': 0.5, 'scale_jitter_up': 1.25}\n",
- "Starting training....\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "iteration: 10 loss: 0.3308 lr: 0.005\n",
- "iteration: 20 loss: 0.0563 lr: 0.005\n",
- "iteration: 30 loss: 0.0417 lr: 0.005\n",
- "iteration: 40 loss: 0.0362 lr: 0.005\n",
- "iteration: 50 loss: 0.0407 lr: 0.005\n",
- "iteration: 60 loss: 0.0461 lr: 0.005\n",
- "iteration: 70 loss: 0.0385 lr: 0.005\n",
- "iteration: 80 loss: 0.0345 lr: 0.005\n",
- "iteration: 90 loss: 0.0314 lr: 0.005\n",
- "iteration: 100 loss: 0.0428 lr: 0.005\n",
- "iteration: 110 loss: 0.0262 lr: 0.005\n",
- "iteration: 120 loss: 0.0255 lr: 0.005\n",
- "iteration: 130 loss: 0.0275 lr: 0.005\n",
- "iteration: 140 loss: 0.0251 lr: 0.005\n",
- "iteration: 150 loss: 0.0221 lr: 0.005\n",
- "iteration: 160 loss: 0.0209 lr: 0.005\n",
- "iteration: 170 loss: 0.0297 lr: 0.005\n",
- "iteration: 180 loss: 0.0325 lr: 0.005\n",
- "iteration: 190 loss: 0.0242 lr: 0.005\n"
- ]
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdeeplabcut\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_network\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath_config_file\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshuffle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdisplayiters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaveiters\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/training.py\u001b[0m in \u001b[0;36mtrain_network\u001b[0;34m(config, shuffle, trainingsetindex, gputouse, max_snapshots_to_keep, autotune, displayiters, saveiters, maxiters)\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposeconfigfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdisplayiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msaveiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmax_to_keep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_snapshots_to_keep\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#pass on path and file name for pose_cfg.yaml!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mBaseException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 89\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 90\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstart_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/training.py\u001b[0m in \u001b[0;36mtrain_network\u001b[0;34m(config, shuffle, trainingsetindex, gputouse, max_snapshots_to_keep, autotune, displayiters, saveiters, maxiters)\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mposeconfigfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdisplayiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0msaveiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxiters\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmax_to_keep\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_snapshots_to_keep\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#pass on path and file name for pose_cfg.yaml!\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mBaseException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(config_yaml, displayiters, saveiters, maxiters, max_to_keep)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mcurrent_lr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr_gen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_lr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m [_, loss_val, summary] = sess.run([train_op, total_loss, merged_summaries],\n\u001b[0;32m--> 142\u001b[0;31m feed_dict={learning_rate: current_lr})\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0mcum_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss_val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0mtrain_writer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msummary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[0;32m--> 900\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1133\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1134\u001b[0m results = self._do_run(handle, final_targets, final_fetches,\n\u001b[0;32m-> 1135\u001b[0;31m feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[1;32m 1136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1137\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_run\u001b[0;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[1;32m 1314\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1315\u001b[0m return self._do_call(_run_fn, feeds, fetches, targets, options,\n\u001b[0;32m-> 1316\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1317\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1318\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_do_call\u001b[0;34m(self, fn, *args)\u001b[0m\n\u001b[1;32m 1320\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1321\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1322\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1323\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[0mmessage\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[0;34m(feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[1;32m 1305\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_extend_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1306\u001b[0m return self._call_tf_sessionrun(\n\u001b[0;32m-> 1307\u001b[0;31m options, feed_dict, fetch_list, target_list, run_metadata)\n\u001b[0m\u001b[1;32m 1308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1309\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m_call_tf_sessionrun\u001b[0;34m(self, options, feed_dict, fetch_list, target_list, run_metadata)\u001b[0m\n\u001b[1;32m 1407\u001b[0m return tf_session.TF_SessionRun_wrapper(\n\u001b[1;32m 1408\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1409\u001b[0;31m run_metadata)\n\u001b[0m\u001b[1;32m 1410\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1411\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraise_exception_on_not_ok_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
+ "outputs": [],
"source": [
- "deeplabcut.train_network(path_config_file, shuffle=1, displayiters=10, saveiters=100)"
+ "# notice the variables \"save_epochs\" and \"displayiters\" that can be set in the function\n",
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=2,\n",
+ " displayiters=5,\n",
+ ")"
]
},
{
@@ -267,98 +145,23 @@
"source": [
"## Evaluate a trained network\n",
"\n",
- "This function evaluates a trained model for a specific shuffle/shuffles at a particular training state (snapshot) or on all the states. The network is evaluated on the data set (images) and stores the results as .csv file in a subdirectory under **evaluation-results**.\n",
+ "This function evaluates a trained model for a specific shuffle/shuffles at a particular training state (snapshot) or on all the states. The network is evaluated on the data set (images) and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**.\n",
"\n",
"You can change various parameters in the ```config.yaml``` file of this project. For evaluation all the model descriptors (Task, TrainingFraction, Date etc.) are important. For the evaluation one can change pcutoff. This cutoff also influences how likely estimated positions need to be so that they are shown in the plots. One can furthermore, change the colormap and dotsize for those graphs."
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kuprPKDdywne",
"scrolled": false
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Config:\n",
- "{'all_joints': [[0], [1], [2], [3]],\n",
- " 'all_joints_names': ['snout', 'leftear', 'rightear', 'tailbase'],\n",
- " 'batch_size': 1,\n",
- " 'bottomheight': 400,\n",
- " 'crop': True,\n",
- " 'crop_pad': 0,\n",
- " 'cropratio': 0.4,\n",
- " 'dataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/openfield_Pranav95shuffle1.mat',\n",
- " 'dataset_type': 'default',\n",
- " 'display_iters': 1000,\n",
- " 'fg_fraction': 0.25,\n",
- " 'global_scale': 0.8,\n",
- " 'init_weights': '/home/mackenzie/anaconda3/envs/DLC2/lib/python3.6/site-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt',\n",
- " 'intermediate_supervision': False,\n",
- " 'intermediate_supervision_layer': 12,\n",
- " 'leftwidth': 400,\n",
- " 'location_refinement': True,\n",
- " 'locref_huber_loss': True,\n",
- " 'locref_loss_weight': 0.05,\n",
- " 'locref_stdev': 7.2801,\n",
- " 'log_dir': 'log',\n",
- " 'max_input_size': 1500,\n",
- " 'mean_pixel': [123.68, 116.779, 103.939],\n",
- " 'metadataset': 'training-datasets/iteration-0/UnaugmentedDataSet_openfieldOct30/Documentation_data-openfield_95shuffle1.pickle',\n",
- " 'min_input_size': 64,\n",
- " 'minsize': 100,\n",
- " 'mirror': False,\n",
- " 'multi_step': [[0.005, 10000],\n",
- " [0.02, 430000],\n",
- " [0.002, 730000],\n",
- " [0.001, 1030000]],\n",
- " 'net_type': 'resnet_50',\n",
- " 'num_joints': 4,\n",
- " 'optimizer': 'sgd',\n",
- " 'pos_dist_thresh': 17,\n",
- " 'project_path': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30',\n",
- " 'regularize': False,\n",
- " 'rightwidth': 400,\n",
- " 'save_iters': 50000,\n",
- " 'scale_jitter_lo': 0.5,\n",
- " 'scale_jitter_up': 1.25,\n",
- " 'scoremap_dir': 'test',\n",
- " 'shuffle': True,\n",
- " 'snapshot_prefix': '/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/dlc-models/iteration-0/openfieldOct30-trainset95shuffle1/test/snapshot',\n",
- " 'stride': 8.0,\n",
- " 'topheight': 400,\n",
- " 'use_gt_segm': False,\n",
- " 'video': False,\n",
- " 'video_batch': False,\n",
- " 'weigh_negatives': False,\n",
- " 'weigh_only_present_joints': False,\n",
- " 'weigh_part_predictions': False,\n",
- " 'weight_decay': 0.0001}\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/evaluation-results/ already exists!\n",
- "/home/mackenzie/DEEPLABCUT/3D/DeepLabCut2.0-master/examples/openfield-Pranav-2018-10-30/evaluation-results/iteration-0/openfieldOct30-trainset95shuffle1 already exists!\n",
- "Running DeepCut_resnet50_openfieldOct30shuffle1_2400 with # of trainingiterations: 2400\n",
- "This net has already been evaluated!\n",
- "The network is evaluated and the results are stored in the subdirectory 'evaluation_results'.\n",
- "If it generalizes well, choose the best model for prediction and update the config file with the appropriate index for the 'snapshotindex'.\n",
- "Use the function 'analyze_video' to make predictions on new videos.\n",
- "Otherwise consider retraining the network (see DeepLabCut workflow Fig 2)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "deeplabcut.evaluate_network(path_config_file,plotting=False)"
+ "deeplabcut.evaluate_network(path_config_file, plotting=False)"
]
},
{
@@ -393,9 +196,7 @@
},
"outputs": [],
"source": [
- "# Creating video path:\n",
- "import os\n",
- "videofile_path = os.path.join(os.getcwd(),'openfield-Pranav-2018-10-30/videos/m3v1mp4.mp4')"
+ "videofile_path = str(Path(path_config_file).parent / \"videos\" / \"m3v1mp4.mp4\")"
]
},
{
@@ -410,8 +211,8 @@
"outputs": [],
"source": [
"print(\"Start analyzing the video!\")\n",
- "#our demo video on a CPU with take ~30 min to analze! GPU is much faster!\n",
- "deeplabcut.analyze_videos(path_config_file,[videofile_path])"
+ "# our demo video on a CPU with take ~5 min to analze! GPU is much faster!\n",
+ "deeplabcut.analyze_videos(path_config_file, [videofile_path])"
]
},
{
@@ -439,7 +240,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,[videofile_path])"
+ "deeplabcut.create_labeled_video(path_config_file, [videofile_path])"
]
},
{
@@ -465,9 +266,13 @@
"outputs": [],
"source": [
"%matplotlib notebook\n",
- "deeplabcut.plot_trajectories(path_config_file,[videofile_path],showfigures=True)\n",
+ "deeplabcut.plot_trajectories(\n",
+ " path_config_file,\n",
+ " [videofile_path],\n",
+ " showfigures=True,\n",
+ ")\n",
"\n",
- "#These plots can are interactive and can be customized (see https://matplotlib.org/)"
+ "# These plots are interactive and can be customized (see https://matplotlib.org/)"
]
},
{
@@ -493,7 +298,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.extract_outlier_frames(path_config_file,[videofile_path])"
+ "deeplabcut.extract_outlier_frames(path_config_file, [videofile_path])"
]
},
{
@@ -515,7 +320,9 @@
"source": [
"## Manually correct labels\n",
"\n",
- "This step allows the user to correct the labels in the extracted frames. Navigate to the folder corresponding to the video 'm3v1mp4' and use the GUI as described in the protocol to update the labels."
+ "This step allows the user to correct the labels in the extracted frames. Navigate to the folder corresponding to the video 'm3v1mp4' and use the GUI as described in the protocol to update the labels.\n",
+ "\n",
+ "For documentation regarding the GUI, [look at the docs for `napari-deeplabcut`](https://github.com/DeepLabCut/napari-deeplabcut/tree/main) - and specifically _\"3. Refining labels – the image folder contains a machinelabels-iter<#>.h5 file.\"_!"
]
},
{
@@ -528,7 +335,6 @@
},
"outputs": [],
"source": [
- "%gui qt6\n",
"deeplabcut.refine_labels(path_config_file)"
]
},
@@ -542,7 +348,7 @@
},
"outputs": [],
"source": [
- "#Perhaps plot the labels to see how how all the frames are annotated (including the refined ones)\n",
+ "# Perhaps plot the labels to see how how all the frames are annotated (including the refined ones)\n",
"deeplabcut.check_labels(path_config_file)"
]
},
@@ -592,7 +398,7 @@
"id": "8fhL6nG2ywoW"
},
"source": [
- "Now one can train the network again... (with the expanded data set)"
+ "Now one can train the network again... (with the expanded data set). We can continue training from the snapshot we already have by using the `snapshot_path` argument - instead of training the model from scratch, it will load the weights we already have and fine-tune them!"
]
},
{
@@ -605,8 +411,31 @@
},
"outputs": [],
"source": [
- "deeplabcut.train_network(path_config_file, shuffle=1)"
+ "snapshot_path = ( # Edit me if needed! Select the path to the snapshot to continue training from!\n",
+ " Path(path_config_file).parent / \n",
+ " \"dlc-models-pytorch\" / \n",
+ " \"iteration-0\" / \n",
+ " \"openfieldOct30-trainset95shuffle1\" / \n",
+ " \"train\" / \n",
+ " \"snapshot-best-080.pt\"\n",
+ ")\n",
+ "\n",
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=2,\n",
+ " displayiters=10,\n",
+ " batch_size=8,\n",
+ " snapshot_path=snapshot_path,\n",
+ ")"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
@@ -617,9 +446,9 @@
"version": "0.3.2"
},
"kernelspec": {
- "display_name": "Python [conda env:DLC2]",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "conda-env-DLC2-py"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -631,7 +460,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.9"
+ "version": "3.11.11"
},
"varInspector": {
"cols": {
diff --git a/examples/JUPYTER/Demo_napari.ipynb b/examples/JUPYTER/Demo_napari.ipynb
index b9ff9512b9..8a017d5276 100644
--- a/examples/JUPYTER/Demo_napari.ipynb
+++ b/examples/JUPYTER/Demo_napari.ipynb
@@ -54,22 +54,13 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "jqLZhp7EoEI0"
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/mwmathis/opt/anaconda3/envs/DLC2K/lib/python3.8/site-packages/statsmodels/compat/pandas.py:65: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n",
- " from pandas import Int64Index as NumericIndex\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import deeplabcut"
]
@@ -84,15 +75,20 @@
},
"outputs": [],
"source": [
- "task='Reaching' # Enter the name of your experiment Task\n",
- "experimenter='Mackenzie' # Enter the name of the experimenter\n",
- "video=['/Users/mwmathis/Documents/DeepLabCut/examples/Reaching-Mackenzie-2018-08-30/videos/reachingvideo1.avi'] # Enter the paths of your videos OR FOLDER you want to grab frames from.\n",
+ "task = \"Reaching\" # Enter the name of your experiment Task\n",
+ "experimenter = \"Mackenzie\" # Enter the name of the experimenter\n",
+ "video = [\n",
+ " \"/Users/mwmathis/Documents/DeepLabCut/examples/Reaching-Mackenzie-2018-08-30/videos/reachingvideo1.avi\"\n",
+ "] # Enter the paths of your videos OR FOLDER you want to grab frames from.\n",
+ "\n",
+ "path_config_file = deeplabcut.create_new_project(task, experimenter, video, copy_videos=True) \n",
"\n",
- "path_config_file=deeplabcut.create_new_project(task,experimenter,video,copy_videos=True) \n",
+ "# NOTE: The function returns the path, where your project is.\n",
"\n",
- "# NOTE: The function returns the path, where your project is. \n",
- "# You could also enter this manually (e.g. if the project is already created and you want to pick up, where you stopped...)\n",
- "#path_config_file = '/home/Mackenzie/Reaching/config.yaml' # Enter the path of the config file that was just created from the above step (check the folder)"
+ "# You could also enter this manually (e.g. if the project is already created and you \n",
+ "# want to pick up, where you stopped...): Enter the path of the config file that was\n",
+ "# just created from the above step (check the folder)\n",
+ "# path_config_file = \"/home/Mackenzie/Reaching/config.yaml\""
]
},
{
@@ -148,9 +144,9 @@
},
"outputs": [],
"source": [
- "#there are other ways to grab frames, such as uniformly; please see the paper:\n",
+ "# there are other ways to grab frames, such as uniformly; please see the paper:\n",
"\n",
- "#AUTOMATIC:\n",
+ "# AUTOMATIC:\n",
"deeplabcut.extract_frames(path_config_file) "
]
},
@@ -245,7 +241,7 @@
"\n",
"After running this script the training dataset is created and saved in the project directory under the subdirectory **'training-datasets'**\n",
"\n",
- "This function also creates new subdirectories under **dlc-models** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pose_cfg.yaml**. For most all use cases we have seen, the defaults are perfectly fine.\n",
+ "This function also creates new subdirectories under **dlc-models-pytorch** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pytorch_config.yaml**. For most all use cases we have seen, the defaults are perfectly fine.\n",
"\n",
"Now it is the time to start training the network!"
]
@@ -299,7 +295,7 @@
"source": [
"## Start evaluating\n",
"This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images)\n",
- "and stores the results as .csv file in a subdirectory under **evaluation-results**"
+ "and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**"
]
},
{
@@ -338,7 +334,7 @@
},
"outputs": [],
"source": [
- "videofile_path = ['videos/video3.avi','videos/video4.avi'] #Enter a folder OR a list of videos to analyze.\n",
+ "videofile_path = ['videos/video3.avi','videos/video4.avi'] # Enter a folder OR a list of videos to analyze.\n",
"\n",
"deeplabcut.analyze_videos(path_config_file,videofile_path, videotype='.avi')"
]
@@ -534,9 +530,9 @@
"version": "0.3.2"
},
"kernelspec": {
- "display_name": "Python [conda env:DLC2K]",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "conda-env-DLC2K-py"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -548,7 +544,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.12"
+ "version": "3.11.11"
},
"varInspector": {
"cols": {
diff --git a/examples/JUPYTER/Demo_yourowndata.ipynb b/examples/JUPYTER/Demo_yourowndata.ipynb
index 6b0478306f..2893dff709 100644
--- a/examples/JUPYTER/Demo_yourowndata.ipynb
+++ b/examples/JUPYTER/Demo_yourowndata.ipynb
@@ -8,7 +8,12 @@
},
"source": [
"# DeepLabCut Toolbox\n",
- "https://github.com/DeepLabCut/DeepLabCut\n",
+ "\n",
+ "\n",
+ "Some resources that can be useful:\n",
+ "\n",
+ "- [github.com/DeepLabCut/DeepLabCut](https://github.com/DeepLabCut/DeepLabCut)\n",
+ "- [DeepLabCut's Documentation: User Guide for Single Animal projects](https://deeplabcut.github.io/DeepLabCut/docs/standardDeepLabCut_UserGuide.html)\n",
"\n",
"This notebook demonstrates the necessary steps to use DeepLabCut for your own project.\n",
"This shows the most simple code to do so, but many of the functions have additional features, so please check out the overview & the protocol paper!\n",
@@ -52,7 +57,7 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@@ -73,15 +78,24 @@
},
"outputs": [],
"source": [
- "task='Reaching' # Enter the name of your experiment Task\n",
- "experimenter='Mackenzie' # Enter the name of the experimenter\n",
- "video=['videos/video1.avi','videos/video2.avi'] # Enter the paths of your videos OR FOLDER you want to grab frames from.\n",
+ "task = \"Reaching\" # Enter the name of your experiment Task\n",
+ "experimenter = \"Mackenzie\" # Enter the name of the experimenter\n",
+ "video = [\n",
+ " \"videos/video1.avi\",\n",
+ " \"videos/video2.avi\",\n",
+ "] # Enter the paths of your videos OR FOLDER you want to grab frames from.\n",
"\n",
- "path_config_file=deeplabcut.create_new_project(task,experimenter,video,copy_videos=True) \n",
+ "path_config_file = deeplabcut.create_new_project(\n",
+ " task,\n",
+ " experimenter,\n",
+ " video,\n",
+ " copy_videos=True,\n",
+ ")\n",
"\n",
"# NOTE: The function returns the path, where your project is. \n",
"# You could also enter this manually (e.g. if the project is already created and you want to pick up, where you stopped...)\n",
- "#path_config_file = '/home/Mackenzie/Reaching/config.yaml' # Enter the path of the config file that was just created from the above step (check the folder)"
+ "# Enter the path of the config file that was just created from the above step (check the folder):\n",
+ "# path_config_file = \"/home/Mackenzie/Reaching/config.yaml\""
]
},
{
@@ -101,7 +115,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -153,7 +167,9 @@
"source": [
"## Label the extracted frames\n",
"\n",
- "Only videos in the config file can be used to extract the frames. Extracted labels for each video are stored in the project directory under the subdirectory **'labeled-data'**. Each subdirectory is named after the name of the video. The toolbox has a labeling toolbox which could be used for labeling. "
+ "Only videos in the config file can be used to extract the frames. Extracted labels for each video are stored in the project directory under the subdirectory **'labeled-data'**. Each subdirectory is named after the name of the video. The toolbox has a labeling toolbox which could be used for labeling. \n",
+ "\n",
+ "Check out [our `napari-deeplabcut` docs](https://github.com/DeepLabCut/napari-deeplabcut/tree/main) for more information about labelling!"
]
},
{
@@ -198,7 +214,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.check_labels(path_config_file) #this creates a subdirectory with the frames + your labels"
+ "deeplabcut.check_labels(path_config_file) # this creates a subdirectory with the frames + your labels"
]
},
{
@@ -224,7 +240,7 @@
"\n",
"After running this script the training dataset is created and saved in the project directory under the subdirectory **'training-datasets'**\n",
"\n",
- "This function also creates new subdirectories under **dlc-models** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pose_cfg.yaml**. For most all use cases we have seen, the defaults are perfectly fine.\n",
+ "This function also creates new subdirectories under **dlc-models-pytorch** and creates a `pytorch_config.yaml` file, defining the model architecture and containing various parameters used for training the network. For most all use cases we have seen, the defaults are perfectly fine. For more information about the variables that can be set, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config.html)!\n",
"\n",
"Now it is the time to start training the network!"
]
@@ -241,7 +257,8 @@
"outputs": [],
"source": [
"deeplabcut.create_training_dataset(path_config_file)\n",
- "#remember, there are several networks you can pick, the default is resnet-50!"
+ "\n",
+ "# remember, there are several networks you can pick, the default is resnet-50!"
]
},
{
@@ -253,6 +270,8 @@
"source": [
"## Start training:\n",
"\n",
+ "The user can set various parameters in `.../project-name/dlc-models-pytorch/.../pytorch_config.yaml`. For more information about the variables that can be set, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config.html)!\n",
+ "\n",
"This function trains the network for a specific shuffle of the training dataset. "
]
},
@@ -278,7 +297,7 @@
"source": [
"## Start evaluating\n",
"This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images)\n",
- "and stores the results as .csv file in a subdirectory under **evaluation-results**"
+ "and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**"
]
},
{
@@ -317,7 +336,7 @@
},
"outputs": [],
"source": [
- "videofile_path = ['videos/video3.avi','videos/video4.avi'] #Enter a folder OR a list of videos to analyze.\n",
+ "videofile_path = ['videos/video3.avi', 'videos/video4.avi'] # Enter a folder OR a list of videos to analyze.\n",
"\n",
"deeplabcut.analyze_videos(path_config_file,videofile_path, videotype='.avi')"
]
@@ -336,7 +355,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -445,18 +464,38 @@
},
"source": [
"## Create labeled video\n",
+ "\n",
"This function is for visualiztion purpose and can be used to create a video in .mp4 format with labels predicted by the network. This video is saved in the same directory where the original video resides. \n",
"\n",
"THIS HAS MANY FUN OPTIONS! \n",
"\n",
- "``deeplabcut.create_labeled_video(config, videos, videotype='avi', shuffle=1, trainingsetindex=0, filtered=False, save_frames=False, Frames2plot=None, delete=False, displayedbodyparts='all', codec='mp4v', outputframerate=None, destfolder=None, draw_skeleton=False, trailpoints=0, displaycropped=False)``\n",
+ "```python\n",
+ "deeplabcut.create_labeled_video(\n",
+ " config,\n",
+ " videos,\n",
+ " videotype='avi',\n",
+ " shuffle=1,\n",
+ " trainingsetindex=0,\n",
+ " filtered=False,\n",
+ " save_frames=False,\n",
+ " Frames2plot=None,\n",
+ " delete=False,\n",
+ " displayedbodyparts='all',\n",
+ " codec='mp4v',\n",
+ " outputframerate=None,\n",
+ " destfolder=None,\n",
+ " draw_skeleton=False,\n",
+ " trailpoints=0,\n",
+ " displaycropped=False,\n",
+ ")\n",
+ "```\n",
"\n",
"So please check:"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -498,7 +537,7 @@
"outputs": [],
"source": [
"%matplotlib notebook #for making interactive plots.\n",
- "deeplabcut.plot_trajectories(path_config_file,videofile_path)"
+ "deeplabcut.plot_trajectories(path_config_file, videofile_path)"
]
}
],
@@ -510,9 +549,9 @@
"version": "0.3.2"
},
"kernelspec": {
- "display_name": "Python [conda env:DLC2]",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "conda-env-DLC2-py"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -524,7 +563,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.9"
+ "version": "3.11.11"
},
"varInspector": {
"cols": {
diff --git a/examples/JUPYTER/Docker_TrainNetwork_VideoAnalysis.ipynb b/examples/JUPYTER/Docker_TrainNetwork_VideoAnalysis.ipynb
index b562493a97..a8ff12e6a5 100644
--- a/examples/JUPYTER/Docker_TrainNetwork_VideoAnalysis.ipynb
+++ b/examples/JUPYTER/Docker_TrainNetwork_VideoAnalysis.ipynb
@@ -56,25 +56,11 @@
},
"outputs": [],
"source": [
- "import tensorflow as tf\n",
- "tf.__version__"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {},
- "colab_type": "code",
- "id": "Pm_PC1Q8lRrH"
- },
- "outputs": [],
- "source": [
- "#let's make sure we see a GPU:\n",
- "#tf.test.gpu_device_name()\n",
- "#or\n",
- "from tensorflow.python.client import device_lib\n",
- "device_lib.list_local_devices()"
+ "import torch\n",
+ "\n",
+ "# Let's make sure we see a GPU:\n",
+ "print(torch.__version__)\n",
+ "print(torch.cuda.is_available())"
]
},
{
@@ -94,10 +80,11 @@
},
"outputs": [],
"source": [
- "#GUIs don't work on in Docker (or the cloud), so label your data locally on your computer! \n",
- "#This notebook is for you to train and run video analysis!\n",
+ "# GUIs don't work on in Docker (or the cloud), so label your data locally on your computer! \n",
+ "# This notebook is for you to train and run video analysis!\n",
"import os\n",
- "os.environ[\"DLClight\"]=\"True\""
+ "\n",
+ "os.environ[\"DLClight\"] = \"True\""
]
},
{
@@ -112,8 +99,7 @@
"outputs": [],
"source": [
"# now we are ready to train!\n",
- "import deeplabcut\n",
- "deeplabcut.__version__"
+ "import deeplabcut"
]
},
{
@@ -133,7 +119,8 @@
},
"outputs": [],
"source": [
- "path_config_file = '/home/mackenzie/DEEPLABCUT/DeepLabCut2.0/examples/Reaching-Mackenzie-2018-08-30/config.yaml' #change to yours!"
+ "# change to yours!\n",
+ "path_config_file = '/home/mackenzie/DEEPLABCUT/DeepLabCut/examples/Reaching-Mackenzie-2018-08-30/config.yaml'"
]
},
{
@@ -155,11 +142,12 @@
},
"source": [
"## Create a training dataset\n",
- "This function generates the training data information for DeepCut (which requires a mat file) based on the pandas dataframes that hold label information. The user can set the fraction of the training set size (from all labeled image in the hd5 file) in the config.yaml file. While creating the dataset, the user can create multiple shuffles. \n",
+ "\n",
+ "This function generates the training data required for DeepLabCut. The user can set the fraction of the training set size (from all labeled images in the hd5 file) in the `config.yaml` file. While creating the dataset, the user can create multiple shuffles. \n",
"\n",
"After running this script the training dataset is created and saved in the project directory under the subdirectory **'training-datasets'**\n",
"\n",
- "This function also creates new subdirectories under **dlc-models** and appends the project config.yaml file with the correct path to the training and testing pose configuration file. These files hold the parameters for training the network. Such an example file is provided with the toolbox and named as **pose_cfg.yaml**."
+ "This function also creates new subdirectories under **dlc-models-pytorch** and creates a `pytorch_config.yaml` file, defining the model architecture and containing various parameters used for training the network. For most all use cases we have seen, the defaults are perfectly fine. For more information about the variables that can be set, check out the [docs](https://deeplabcut.github.io/DeepLabCut/docs/pytorch/pytorch_config.html)!\n"
]
},
{
@@ -168,16 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
- "deeplabcut.create_training_dataset(path_config_file,Shuffles=[1])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### now go edit the pose_cfg.yaml to make display_iters: low (i.e. 10), and save_iters: 500 (for demo's)\n",
- "\n",
- "Now it is the time to start training the network!"
+ "deeplabcut.create_training_dataset(path_config_file, Shuffles=[1])"
]
},
{
@@ -202,54 +181,18 @@
},
"outputs": [],
"source": [
- "#reset in case you started a session before...\n",
- "#tf.reset_default_graph()\n",
- "\n",
- "deeplabcut.train_network(path_config_file, shuffle=1, saveiters=1000, displayiters=10)\n",
- "\n",
- "#this will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it hits the end (default, 1.3M iterations). \n",
- "#Whichever you chose, you will see what looks like an error message, but it's not an error - don't worry....\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Firstly, if the above cell ran, you can stop it with \"stop\" or cntrl-C; you will get a Keyboard Interrupt error (this is fine!)\n",
+ "deeplabcut.train_network(\n",
+ " path_config_file,\n",
+ " shuffle=1,\n",
+ " save_epochs=2,\n",
+ " displayiters=5,\n",
+ ")\n",
"\n",
- "### A couple tips for possible troubleshooting (1): \n",
+ "# This will run until you stop it (CTRL+C), or hit \"STOP\" icon, or when it\n",
+ "# hits the end (default, 200 epochs).\n",
"\n",
- "if you get **permission errors** when you run this step (above), first check if the weights downloaded. As some docker containers might not have privileges for this (it can be user specific). They should be under 'init_weights' (see path in the pose_cfg.yaml file). You can enter the DOCKER in the terminal:"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "see more here: https://github.com/MMathisLab/Docker4DeepLabCut2.0#using-the-docker-for-training-and-video-analysis"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "You can \"cd\" in the terminal to this location! i.e. copy and paste this in: **\"cd usr/local/lib/python3.6/dist-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/\n",
- "\"** \n",
- "\n",
- "And if you type \"ls\" to see the list of files, you should see the resnet:\n",
- "**resnet_v1_50.ckpt**\n",
- "\n",
- "If it is not there, run **\"sudo download.sh\"**\n",
- "then change the permissions: **\"sudo chown yourusername:yourusername resnet_v1_50.ckpt\"**\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Troubleshooting (2): \n",
- "if it appears the training does not start (i.e. \"Starting training...\" does not print immediately),\n",
- "then you have another session running on your GPU. Go check \"nvidia-smi\" and look at the process names. You can only have 1 per GPU!)"
+ "# If you end training before it hits the end, you will see what looks like\n",
+ "# an error message, but it's not an error - don't worry....\n"
]
},
{
@@ -261,7 +204,7 @@
"source": [
"## Start evaluating\n",
"This function evaluates a trained model for a specific shuffle/shuffles at a particular state or all the states on the data set (images)\n",
- "and stores the results as .csv file in a subdirectory under **evaluation-results**"
+ "and stores the results as .csv file in a subdirectory under **evaluation-results-pytorch**"
]
},
{
@@ -277,7 +220,8 @@
"source": [
"deeplabcut.evaluate_network(path_config_file)\n",
"\n",
- "# Here you want to see a low pixel error! Of course, it can only be as good as the labeler, so be sure your labels are good!"
+ "# Here you want to see a low pixel error! Of course, it can only\n",
+ "# be as good as the labeler, so be sure your labels are good!"
]
},
{
@@ -317,8 +261,10 @@
},
"outputs": [],
"source": [
- "videofile_path = ['/home/mackenzie/DEEPLABCUT/DeepLabCut2.0/examples/Reaching-Mackenzie-2018-08-30/videos/MovieS2_Perturbation_noLaser_compressed.avi'] #Enter the list of videos to analyze.\n",
- "deeplabcut.analyze_videos(path_config_file,videofile_path)"
+ "videofile_path = [\n",
+ " \"/home/mackenzie/DEEPLABCUT/DeepLabCut/examples/Reaching-Mackenzie-2018-08-30/videos/MovieS2_Perturbation_noLaser_compressed.avi\"\n",
+ "] # Enter the list of videos to analyze.\n",
+ "deeplabcut.analyze_videos(path_config_file, videofile_path)"
]
},
{
@@ -343,7 +289,7 @@
},
"outputs": [],
"source": [
- "deeplabcut.create_labeled_video(path_config_file,videofile_path)"
+ "deeplabcut.create_labeled_video(path_config_file, videofile_path)"
]
},
{
@@ -369,10 +315,9 @@
"outputs": [],
"source": [
"%matplotlib notebook \n",
- "#for making interactive plots.\n",
- "#deeplabcut.plot_trajectories(path_config_file,videofile_path, plotting=True)\n",
- "\n",
- "deeplabcut.plot_trajectories(path_config_file,videofile_path,showfigures=True)"
+ "# for making interactive plots.\n",
+ "# deeplabcut.plot_trajectories(path_config_file, videofile_path, plotting=True)\n",
+ "deeplabcut.plot_trajectories(path_config_file, videofile_path, showfigures=True)"
]
}
],
@@ -387,7 +332,7 @@
"version": "0.3.2"
},
"kernelspec": {
- "display_name": "Python [default]",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -401,7 +346,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.6"
+ "version": "3.11.11"
},
"varInspector": {
"cols": {
diff --git a/examples/SUPERANIMAL/eval_zeroshot.py b/examples/SUPERANIMAL/eval_zeroshot.py
new file mode 100644
index 0000000000..600e06e988
--- /dev/null
+++ b/examples/SUPERANIMAL/eval_zeroshot.py
@@ -0,0 +1,112 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""SuperAnimal model zero-shot evaluation"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import torch
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.generate_training_dataset import TrainingDatasetMetadata
+from deeplabcut.modelzoo import build_weight_init
+from deeplabcut.pose_estimation_pytorch import DLCLoader
+from deeplabcut.pose_estimation_pytorch.apis.evaluation import evaluate_snapshot
+from deeplabcut.pose_estimation_pytorch.models import PoseModel
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
+
+
+def main(
+ config_path: Path,
+ super_animal: str,
+ shuffle_index: int,
+ device: str,
+ super_animal_model: str = "hrnet_w32",
+ super_animal_detector: str = "fasterrcnn_resnet50_fpn_v2",
+):
+ metadata = TrainingDatasetMetadata.load(config_path, load_splits=True)
+ shuffles = [s for s in metadata.shuffles if s.index == shuffle_index]
+ if len(shuffles) != 1:
+ raise ValueError(
+ "Found multiple shuffles with different train indices but the same index "
+ f"({shuffles}). To run this benchmark, there should only be one such "
+ "shuffle."
+ )
+
+ shuffle = shuffles[0]
+ print(f"Training shuffle: {shuffle.name}")
+ print(f" index: {shuffle.index}")
+ print(f" train fraction: {shuffle.train_fraction}")
+ print(f" train indices: {shuffle.split.train_indices}")
+ print(f" test indices: {shuffle.split.test_indices}")
+ print()
+
+ # edit config to have the desired training fraction
+ af.edit_config(str(config_path), {"TrainingFraction": [shuffle.train_fraction]})
+
+ # Load the config and create a data loader
+ cfg = af.read_config(str(config_path))
+ loader = DLCLoader(
+ config=Path(cfg["project_path"]) / "config.yaml",
+ shuffle=shuffle.index,
+ trainset_index=0,
+ modelprefix="",
+ )
+ loader.evaluation_folder.mkdir(exist_ok=True, parents=True)
+ loader.model_cfg["device"] = device
+
+ # Build the pose model
+ model = PoseModel.build(
+ loader.model_cfg["model"],
+ weight_init=build_weight_init(
+ cfg=cfg,
+ super_animal=super_animal,
+ model_name=super_animal_model,
+ detector_name=super_animal_detector,
+ with_decoder=True,
+ ),
+ )
+
+ # Save the zero-shot snapshot
+ state_dict = {
+ "model": model.state_dict(),
+ "metadata": {
+ "epoch": 0,
+ "metrics": {},
+ "losses": {},
+ },
+ }
+ snapshot_path = loader.model_folder / "zero-shot.pt"
+ torch.save(state_dict, snapshot_path)
+
+ # Evaluate the snapshot
+ evaluate_snapshot(
+ loader=loader,
+ cfg=cfg,
+ scorer=f"{super_animal}-zero-shot",
+ snapshot=Snapshot(best=False, epochs=0, path=snapshot_path),
+ transform=None,
+ plotting=True,
+ show_errors=True,
+ detector_snapshot=None,
+ )
+
+
+if __name__ == "__main__":
+ DATA = Path("/home/niels/datasets/superanimal")
+ CONFIG_PATH = DATA / "openfield-Pranav-2018-08-20" / "config.yaml"
+ SUPER_ANIMAL = "superanimal_topviewmouse"
+ main(
+ config_path=CONFIG_PATH,
+ super_animal=SUPER_ANIMAL,
+ shuffle_index=1001,
+ device="cuda",
+ )
diff --git a/examples/SUPERANIMAL/keypoint_space_conversion.py b/examples/SUPERANIMAL/keypoint_space_conversion.py
new file mode 100644
index 0000000000..ea3d7b4adb
--- /dev/null
+++ b/examples/SUPERANIMAL/keypoint_space_conversion.py
@@ -0,0 +1,36 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Script to convert a dataset for its keypoint space to match the SuperAnimal space"""
+from pathlib import Path
+
+from deeplabcut.modelzoo.generalized_data_converter.datasets import COCOPoseDataset
+from deeplabcut.utils.auxiliaryfunctions import get_deeplabcut_path
+
+
+def main():
+ src_proj_root = Path("/media/data/trimouse_coco_original_shuffle0")
+ conversion_table_path = (
+ Path(get_deeplabcut_path())
+ / "modelzoo"
+ / "conversion_tables"
+ / "conversion_table_topview.csv"
+ )
+ dataset = COCOPoseDataset(str(src_proj_root), "trimouse")
+ dataset.project_with_conversion_table(conversion_table_path)
+ dataset.materialize(
+ src_proj_root.with_name("trimouse_coco_superanimal_shuffle0_shallow_copy"),
+ deepcopy=False,
+ framework="coco",
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/SUPERANIMAL/memory_replay_example.py b/examples/SUPERANIMAL/memory_replay_example.py
new file mode 100644
index 0000000000..8559abfb68
--- /dev/null
+++ b/examples/SUPERANIMAL/memory_replay_example.py
@@ -0,0 +1,92 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Script to fine-tune a SuperAnimal model with memory replay"""
+from pathlib import Path
+
+import deeplabcut
+from deeplabcut.core.engine import Engine
+from deeplabcut.core.weight_init import WeightInitialization
+from deeplabcut.modelzoo.utils import (
+ create_conversion_table,
+ read_conversion_table_from_csv,
+)
+from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
+ get_super_animal_snapshot_path,
+)
+from deeplabcut.utils.pseudo_label import keypoint_matching
+
+
+def main(
+ dlc_proj_root: Path,
+ super_animal_name: str,
+ super_animal_model: str = "hrnet_w32",
+ super_animal_detector: str = "fasterrcnn_resnet50_fpn_v2",
+):
+ config_path = str(dlc_proj_root / "config.yaml")
+ shuffle = 0
+ device = "cuda"
+
+ # keypoint matching before create training dataset
+ # keypoint matching creates pseudo prediction and a conversion table
+ keypoint_matching(
+ config_path,
+ super_animal_name,
+ super_animal_model,
+ super_animal_detector,
+ )
+
+ # keypoint matching creates a memory_replay folder in the root. The conversion table
+ # can be read from there
+ conversion_table_path = dlc_proj_root / "memory_replay" / "conversion_table.csv"
+
+ table = create_conversion_table(
+ config=config_path,
+ super_animal=super_animal_name,
+ project_to_super_animal=read_conversion_table_from_csv(conversion_table_path),
+ )
+
+ weight_init = WeightInitialization(
+ dataset=super_animal_name,
+ snapshot_path=get_super_animal_snapshot_path(
+ dataset=super_animal_name,
+ model_name=super_animal_model,
+ download=True,
+ ),
+ detector_snapshot_path=get_super_animal_snapshot_path(
+ dataset=super_animal_name,
+ model_name=super_animal_detector,
+ download=True,
+ ),
+ conversion_array=table.to_array(),
+ with_decoder=True,
+ )
+
+ deeplabcut.create_training_dataset(
+ config_path,
+ Shuffles=[shuffle],
+ net_type="top_down_hrnet_w32",
+ weight_init=weight_init,
+ engine=Engine.PYTORCH,
+ userfeedback=False,
+ )
+
+ # passing pose_threshold controls the behavior of memory replay. We discard
+ # predictions that are lower than the threshold
+ deeplabcut.train_network(
+ config_path, shuffle=shuffle, device=device, pose_threshold=0.1
+ )
+
+
+if __name__ == "__main__":
+ main(
+ dlc_proj_root=Path("/media/data/myproject"),
+ super_animal_name="superanimal_topviewmouse",
+ )
diff --git a/examples/SUPERANIMAL/superanimal_image_inference.py b/examples/SUPERANIMAL/superanimal_image_inference.py
new file mode 100644
index 0000000000..26fd4e5b77
--- /dev/null
+++ b/examples/SUPERANIMAL/superanimal_image_inference.py
@@ -0,0 +1,30 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from deeplabcut.pose_estimation_pytorch.apis.analyze_images import (
+ superanimal_analyze_images,
+)
+
+
+if __name__ == "__main__":
+ superanimal_name = "superanimal_quadruped"
+ model_name = "hrnet_w32"
+ detector_name = "fasterrcnn_resnet50_fpn_v2"
+ device = "cuda"
+ max_individuals = 3
+
+ ret = superanimal_analyze_images(
+ superanimal_name,
+ model_name,
+ detector_name,
+ "test_rodent_images",
+ max_individuals,
+ "vis_test_rodent_images",
+ )
diff --git a/examples/SUPERANIMAL/video_adapt_example.py b/examples/SUPERANIMAL/video_adapt_example.py
new file mode 100644
index 0000000000..217e7bcfb9
--- /dev/null
+++ b/examples/SUPERANIMAL/video_adapt_example.py
@@ -0,0 +1,31 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Script to run video adaptation"""
+import deeplabcut.modelzoo.video_inference as modelzoo
+
+
+def main():
+ modelzoo.video_inference_superanimal(
+ videos=["/mnt/md0/shaokai/tom_video.mp4"],
+ superanimal_name="superanimal_topviewmouse",
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
+ video_adapt=True,
+ max_individuals=3,
+ pseudo_threshold=0.1,
+ bbox_threshold=0.9,
+ detector_epochs=1,
+ pose_epochs=1,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/openfield-Pranav-2018-10-30/config.yaml b/examples/openfield-Pranav-2018-10-30/config.yaml
index 64c2ce17b2..ed8c31fdf3 100644
--- a/examples/openfield-Pranav-2018-10-30/config.yaml
+++ b/examples/openfield-Pranav-2018-10-30/config.yaml
@@ -2,10 +2,17 @@
Task: openfield
scorer: Pranav
date: Oct30
+multianimalproject:
+identity:
+
# Project path (change when moving around)
project_path: WILL BE AUTOMATICALLY UPDATED BY DEMO CODE
+# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
+engine: pytorch
+
+
# Annotation data set configuration (and individual video cropping parameters)
video_sets:
WILL BE AUTOMATICALLY UPDATED BY DEMO CODE:
@@ -16,23 +23,33 @@ bodyparts:
- rightear
- tailbase
+
+# Fraction of video to start/stop when extracting frames for labeling/refinement
start: 0
stop: 1
numframes2pick: 20
+
# Plotting configuration
+skeleton: []
+skeleton_color: black
pcutoff: 0.4
dotsize: 8
alphavalue: 0.7
colormap: jet
+
# Training,Evaluation and Analysis configuration
TrainingFraction:
- 0.95
iteration: 0
default_net_type: resnet_50
+default_augmenter: imgaug
snapshotindex: -1
+detector_snapshotindex: -1
batch_size: 4
+detector_batch_size: 1
+
# Cropping Parameters (for analysis and outlier frame detection)
cropping: false
@@ -42,8 +59,18 @@ x2: 640
y1: 277
y2: 624
+
# Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
corner2move2:
- 50
- 50
move2corner: true
+
+
+# Conversion tables to fine-tune SuperAnimal weights
+SuperAnimalConversionTables:
+ superanimal_topviewmouse:
+ snout: nose
+ leftear: left_ear
+ rightear: right_ear
+ tailbase: tail_base
diff --git a/examples/test.sh b/examples/test.sh
index 3b3d2a6b90..bad0bdb35d 100755
--- a/examples/test.sh
+++ b/examples/test.sh
@@ -6,7 +6,7 @@ rm -r OUT
cd ..
pip uninstall deeplabcut
python3 setup.py sdist bdist_wheel
-pip install dist/deeplabcut-2.3.11-py3-none-any.whl
+pip install dist/deeplabcut-3.0.0rc6-none-any.whl
cd examples
diff --git a/examples/testscript.py b/examples/testscript.py
index fc984a6905..c09a399d3d 100644
--- a/examples/testscript.py
+++ b/examples/testscript.py
@@ -23,19 +23,18 @@
It produces nothing of interest scientifically.
"""
import os
-import deeplabcut
import platform
-import scipy.io as sio
-import subprocess
+import random
from pathlib import Path
import numpy as np
import pandas as pd
+import scipy.io as sio
+import deeplabcut
+from deeplabcut.core.engine import Engine
from deeplabcut.utils import auxiliaryfunctions
-import random
-
USE_SHELVE = random.choice([True, False])
MODELS = ["resnet_50", "efficientnet-b0", "mobilenet_v2_0.35"]
@@ -43,6 +42,7 @@
if __name__ == "__main__":
task = "TEST" # Enter the name of your experiment Task
scorer = "Alex" # Enter the name of the experimenter/labeler
+ engine = Engine.TF
print("Imported DLC!")
basepath = os.path.dirname(os.path.realpath(__file__))
@@ -92,6 +92,7 @@
print("CREATING-SOME LABELS FOR THE FRAMES")
frames = os.listdir(os.path.join(cfg["project_path"], "labeled-data", videoname))
+ frames = [fn for fn in frames if fn.endswith(".png")]
# As this next step is manual, we update the labels by putting them on the diagonal (fixed for all frames)
for index, bodypart in enumerate(cfg["bodyparts"]):
columnindex = pd.MultiIndex.from_product(
@@ -134,7 +135,7 @@
print("CREATING TRAININGSET")
deeplabcut.create_training_dataset(
- path_config_file, net_type=NET, augmenter_type=augmenter_type
+ path_config_file, net_type=NET, augmenter_type=augmenter_type, engine=engine,
)
# Check the training image paths are correctly stored as arrays of strings
@@ -204,7 +205,7 @@
except: # if ffmpeg is broken/missing
print("using alternative method")
newvideo = os.path.join(cfg["project_path"], "videos", videoname + "short.mp4")
- from moviepy.editor import VideoFileClip, VideoClip
+ from moviepy.editor import VideoClip, VideoFileClip
clip = VideoFileClip(video[0])
clip.reader.initialize()
@@ -231,9 +232,10 @@ def make_frame(t):
)
print("CREATE VIDEO")
- deeplabcut.create_labeled_video(
+ successful = deeplabcut.create_labeled_video(
path_config_file, [newvideo], destfolder=DESTFOLDER, save_frames=True
)
+ assert all(successful), f"Failed to create a labeled video!"
print("Making plots")
deeplabcut.plot_trajectories(path_config_file, [newvideo], destfolder=DESTFOLDER)
@@ -291,7 +293,7 @@ def make_frame(t):
print("CREATING TRAININGSET")
deeplabcut.create_training_dataset(
- path_config_file, net_type=NET, augmenter_type=augmenter_type2
+ path_config_file, net_type=NET, augmenter_type=augmenter_type2, engine=engine
)
cfg = deeplabcut.auxiliaryfunctions.read_config(path_config_file)
@@ -332,7 +334,7 @@ def make_frame(t):
newvideo2 = os.path.join(
cfg["project_path"], "videos", videoname + "short2.mp4"
)
- from moviepy.editor import VideoFileClip, VideoClip
+ from moviepy.editor import VideoClip, VideoFileClip
clip = VideoFileClip(video[0])
clip.reader.initialize()
@@ -362,18 +364,20 @@ def make_frame(t):
)
deeplabcut.filterpredictions(path_config_file, [newvideo2])
- deeplabcut.create_labeled_video(
+ successful = deeplabcut.create_labeled_video(
path_config_file,
[newvideo2],
destfolder=DESTFOLDER,
displaycropped=True,
filtered=True,
)
+ assert all(successful), f"Failed to create a labeled video!"
print("Creating a Johansson video!")
- deeplabcut.create_labeled_video(
+ successful = deeplabcut.create_labeled_video(
path_config_file, [newvideo2], destfolder=DESTFOLDER, keypoints_only=True
)
+ assert all(successful), f"Failed to create a labeled video!"
deeplabcut.plot_trajectories(
path_config_file, [newvideo2], destfolder=DESTFOLDER, filtered=True
@@ -389,6 +393,7 @@ def make_frame(t):
Shuffles=[2],
net_type=NET,
augmenter_type=augmenter_type3,
+ engine=engine,
)
posefile = os.path.join(
@@ -438,6 +443,7 @@ def make_frame(t):
Shuffles=[4, 5],
trainIndices=[trainIndices, trainIndices],
testIndices=[testIndices, testIndices],
+ engine=engine,
)
print("ALL DONE!!! - default cases are functional.")
diff --git a/examples/testscript_multianimal.py b/examples/testscript_multianimal.py
index f1215d13e5..a1c1003c16 100644
--- a/examples/testscript_multianimal.py
+++ b/examples/testscript_multianimal.py
@@ -9,14 +9,17 @@
# Licensed under GNU Lesser General Public License v3.0
#
import os
-import deeplabcut
+import pickle
+import random
+from pathlib import Path
+
import numpy as np
import pandas as pd
-import pickle
+
+import deeplabcut
+from deeplabcut.core.engine import Engine
from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions
from deeplabcut.utils.auxfun_videos import VideoReader
-import random
-from pathlib import Path
MODELS = ["dlcrnet_ms5", "dlcr101_ms5", "efficientnet-b0", "mobilenet_v2_0.35"]
@@ -31,6 +34,7 @@
SCORER = "dlc_team"
NUM_FRAMES = 5
TRAIN_SIZE = 0.8
+ ENGINE = Engine.TF
# NET = "dlcr101_ms5"
NET = "dlcrnet_ms5"
@@ -114,7 +118,7 @@
print("Creating train dataset...")
deeplabcut.create_multianimaltraining_dataset(
- config_path, net_type=NET, crop_size=(200, 200)
+ config_path, net_type=NET, crop_size=(200, 200), engine=ENGINE,
)
print("Train dataset created.")
@@ -134,7 +138,7 @@
print("Editing pose config...")
model_folder = auxiliaryfunctions.get_model_folder(
- TRAIN_SIZE, 1, cfg, cfg["project_path"]
+ TRAIN_SIZE, 1, cfg, engine=ENGINE, modelprefix=cfg["project_path"]
)
pose_config_path = os.path.join(model_folder, "train", "pose_cfg.yaml")
edits = {
@@ -291,7 +295,7 @@
deeplabcut.merge_datasets(config_path) # iteration + 1
print("CREATING TRAININGSET updated training set")
- deeplabcut.create_training_dataset(config_path, net_type=NET)
+ deeplabcut.create_training_dataset(config_path, net_type=NET, engine=ENGINE)
print("Training network...")
deeplabcut.train_network(config_path, maxiters=N_ITER)
@@ -330,6 +334,7 @@
Shuffles=[4, 5],
trainIndices=[trainIndices, trainIndices],
testIndices=[testIndices, testIndices],
+ engine=ENGINE,
)
print("ALL DONE!!! - default multianimal cases are functional.")
diff --git a/examples/testscript_pretrained_models.py b/examples/testscript_pretrained_models.py
index 3b09d0bc26..04fb690cee 100644
--- a/examples/testscript_pretrained_models.py
+++ b/examples/testscript_pretrained_models.py
@@ -45,6 +45,7 @@
analyzevideo=True,
createlabeledvideo=True,
copy_videos=False,
+ engine=deeplabcut.Engine.TF,
) # must leave copy_videos=True
diff --git a/examples/testscript_pytorch_multi_animal.py b/examples/testscript_pytorch_multi_animal.py
new file mode 100644
index 0000000000..db7d797372
--- /dev/null
+++ b/examples/testscript_pytorch_multi_animal.py
@@ -0,0 +1,119 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Testscript for single animal PyTorch projects"""
+from __future__ import annotations
+
+from pathlib import Path
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.compat import Engine
+
+from utils import (
+ cleanup,
+ create_fake_project,
+ log_step,
+ run,
+ SyntheticProjectParameters,
+)
+
+
+def main(
+ net_types: list[str],
+ params: SyntheticProjectParameters,
+ epochs: int = 1,
+ top_down_epochs: int = 1,
+ detector_epochs: int = 1,
+ save_epochs: int = 1,
+ batch_size: int = 1,
+ detector_batch_size: int = 1,
+ max_snapshots_to_keep: int = 5,
+ device: str = "cpu",
+ logger: dict | None = None,
+ create_labeled_videos: bool = False,
+ delete_after_test_run: bool = False,
+) -> None:
+ project_path = Path("synthetic-data-niels-multi-animal").resolve()
+ config_path = project_path / "config.yaml"
+ create_fake_project(path=project_path, params=params)
+
+ engine = Engine.PYTORCH
+ cfg = af.read_config(config_path)
+ trainset_index = 0
+ train_frac = cfg["TrainingFraction"][trainset_index]
+ try:
+ for net_type in net_types:
+ epochs_ = epochs
+ if "top_down" in net_type:
+ epochs_ = top_down_epochs
+ try:
+ run(
+ config_path=config_path,
+ train_fraction=train_frac,
+ trainset_index=trainset_index,
+ net_type=net_type,
+ videos=[str(project_path / "videos" / "video.mp4")],
+ device=device,
+ engine=engine,
+ pytorch_cfg_updates={
+ "train_settings.display_iters": 50,
+ "train_settings.epochs": epochs_,
+ "train_settings.batch_size": batch_size,
+ "runner.device": device,
+ "runner.snapshots.save_epochs": save_epochs,
+ "runner.snapshots.max_snapshots": max_snapshots_to_keep,
+ "detector.train_settings.display_iters": 1,
+ "detector.train_settings.epochs": detector_epochs,
+ "detector.train_settings.batch_size": detector_batch_size,
+ "detector.runner.snapshots.save_epochs": save_epochs,
+ "detector.runner.snapshots.max_snapshots": max_snapshots_to_keep,
+ "logger": logger,
+ },
+ create_labeled_videos=create_labeled_videos,
+ )
+ except Exception as err:
+ log_step(f"FAILED TO RUN {net_type}")
+ log_step(str(err))
+ log_step("Continuing to next model")
+ raise err
+
+ finally:
+ if delete_after_test_run:
+ cleanup(project_path)
+
+
+if __name__ == "__main__":
+ wandb_logger = {
+ "type": "WandbLogger",
+ "project_name": "testscript-dev",
+ "run_name": "test-logging",
+ }
+ main(
+ net_types=["top_down_resnet_50", "resnet_50", "dekr_w32"],
+ params=SyntheticProjectParameters(
+ multianimal=True,
+ num_bodyparts=4,
+ num_individuals=3,
+ num_unique=0,
+ num_frames=25,
+ frame_shape=(256, 256),
+ ),
+ batch_size=2,
+ detector_batch_size=2,
+ epochs=8,
+ top_down_epochs=2,
+ detector_epochs=10,
+ save_epochs=4,
+ max_snapshots_to_keep=2,
+ device="cpu", # "cpu", "cuda:0", "mps"
+ logger=None,
+ create_labeled_videos=True,
+ delete_after_test_run=True,
+ )
diff --git a/examples/testscript_pytorch_single_animal.py b/examples/testscript_pytorch_single_animal.py
new file mode 100644
index 0000000000..a2cbdebc1c
--- /dev/null
+++ b/examples/testscript_pytorch_single_animal.py
@@ -0,0 +1,108 @@
+"""Testscript for single animal PyTorch projects"""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.compat import Engine
+
+from utils import (
+ cleanup,
+ copy_project_for_test,
+ create_fake_project,
+ log_step,
+ run,
+ SyntheticProjectParameters,
+)
+
+
+def main(
+ synthetic_data: bool,
+ net_types: list[str],
+ epochs: int = 1,
+ save_epochs: int = 1,
+ max_snapshots_to_keep: int = 5,
+ batch_size: int = 1,
+ device: str = "cpu",
+ logger: dict | None = None,
+ synthetic_data_params: SyntheticProjectParameters = SyntheticProjectParameters(
+ multianimal=False,
+ num_bodyparts=6,
+ ),
+ create_labeled_videos: bool = False,
+ delete_after_test_run: bool = False,
+) -> None:
+ engine = Engine.PYTORCH
+ if synthetic_data:
+ project_path = Path("synthetic-data-niels-single-animal").resolve()
+ videos = [str(project_path / "videos" / "video.mp4")]
+ create_fake_project(path=project_path, params=synthetic_data_params)
+
+ else:
+ project_path = copy_project_for_test()
+ videos = [str(project_path / "videos" / "m3v1mp4.mp4")]
+
+ config_path = project_path / "config.yaml"
+ cfg = af.read_config(config_path)
+ trainset_index = 0
+ train_frac = cfg["TrainingFraction"][trainset_index]
+ try:
+ for net_type in net_types:
+ try:
+ run(
+ config_path=config_path,
+ train_fraction=train_frac,
+ trainset_index=trainset_index,
+ net_type=net_type,
+ videos=videos,
+ device=device,
+ engine=engine,
+ pytorch_cfg_updates={
+ "train_settings.display_iters": 50,
+ "train_settings.epochs": epochs,
+ "train_settings.batch_size": batch_size,
+ "runner.device": device,
+ "runner.snapshots.save_epochs": save_epochs,
+ "runner.snapshots.max_snapshots": max_snapshots_to_keep,
+ "logger": logger,
+ },
+ create_labeled_videos=create_labeled_videos,
+ )
+
+ except Exception as err:
+ log_step(f"FAILED TO RUN {net_type}")
+ log_step(str(err))
+ log_step("Continuing to next model")
+ raise err
+ finally:
+ if delete_after_test_run:
+ cleanup(project_path)
+
+
+if __name__ == "__main__":
+ wandb_logger = {
+ "type": "WandbLogger",
+ "project_name": "testscript-dev",
+ "run_name": "test-logging",
+ }
+ main(
+ synthetic_data=True,
+ net_types=["cspnext_m", "resnet_50", "hrnet_w32"],
+ batch_size=4,
+ epochs=8,
+ save_epochs=2,
+ max_snapshots_to_keep=2,
+ device="cpu", # "cpu", "cuda:0", "mps"
+ logger=None,
+ synthetic_data_params=SyntheticProjectParameters(
+ multianimal=False,
+ num_bodyparts=4,
+ num_individuals=1,
+ num_unique=0,
+ num_frames=12,
+ frame_shape=(128, 128),
+ ),
+ create_labeled_videos=True,
+ delete_after_test_run=True,
+ )
diff --git a/examples/testscript_superanimal_adaptation.py b/examples/testscript_superanimal_adaptation.py
index 02a660313d..a4d0e15d9d 100644
--- a/examples/testscript_superanimal_adaptation.py
+++ b/examples/testscript_superanimal_adaptation.py
@@ -36,6 +36,8 @@
deeplabcut.video_inference_superanimal(
[video],
superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
videotype=".mp4",
video_adapt=True,
scale_list=scale_list,
diff --git a/examples/testscript_superanimal_create_pretrained_project.py b/examples/testscript_superanimal_create_pretrained_project.py
new file mode 100644
index 0000000000..1ac4e5ed15
--- /dev/null
+++ b/examples/testscript_superanimal_create_pretrained_project.py
@@ -0,0 +1,40 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""
+Testscript for creating a pretrained project from a super animal model
+
+"""
+import glob
+import shutil
+from pathlib import Path
+
+import deeplabcut
+
+if __name__ == "__main__":
+ superanimal_name = "superanimal_quadruped"
+ working_dir = Path(__file__).resolve().parent
+ video_dir = working_dir / "openfield-Pranav-2018-10-30/videos/m3v1mp4short.mp4"
+ project_name = "pretrained"
+
+ deeplabcut.create_pretrained_project(
+ project_name,
+ "max",
+ [str(video_dir)],
+ engine=deeplabcut.Engine.PYTORCH,
+ )
+
+ dirs_to_delete = glob.glob(f"{working_dir}/{project_name}*")
+
+ # Delete directories
+ for directory in dirs_to_delete:
+ shutil.rmtree(directory)
+
+ print("Test passed!")
diff --git a/examples/testscript_superanimal_inference.py b/examples/testscript_superanimal_inference.py
index b4b49c42b6..c0a042e08a 100644
--- a/examples/testscript_superanimal_inference.py
+++ b/examples/testscript_superanimal_inference.py
@@ -31,6 +31,8 @@
deeplabcut.video_inference_superanimal(
video,
superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
videotype=".avi",
scale_list=scale_list,
)
@@ -40,6 +42,8 @@
deeplabcut.video_inference_superanimal(
video,
superanimal_name,
+ model_name="hrnet_w32",
+ detector_name="fasterrcnn_resnet50_fpn_v2",
videotype=".avi",
scale_list=scale_list,
)
diff --git a/examples/testscript_superanimal_transfer_learning.py b/examples/testscript_superanimal_transfer_learning.py
index 1f36f99074..887d7ee174 100644
--- a/examples/testscript_superanimal_transfer_learning.py
+++ b/examples/testscript_superanimal_transfer_learning.py
@@ -11,21 +11,32 @@
"""
Test script for super animal adaptation
"""
-import deeplabcut
import os
+import deeplabcut
+from deeplabcut.modelzoo.weight_initialization import build_weight_init
+
print(deeplabcut.__file__)
if __name__ == "__main__":
superanimal_name = "superanimal_topviewmouse"
basepath = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(basepath, "openfield-Pranav-2018-10-30", "config.yaml")
+ model_name = "hrnet_w32"
+ detector_name = "fasterrcnn_resnet50_fpn_v2"
- deeplabcut.create_training_dataset(config_path, superanimal_name=superanimal_name)
+ weight_init = build_weight_init(
+ cfg=config_path,
+ super_animal=superanimal_name,
+ model_name=model_name,
+ detector_name=detector_name,
+ with_decoder=False,
+ )
+ deeplabcut.create_training_dataset(config_path, weight_init=weight_init)
deeplabcut.train_network(
config_path,
- maxiters=10,
+ epochs=1,
superanimal_name=superanimal_name,
superanimal_transfer_learning=True,
)
diff --git a/examples/utils.py b/examples/utils.py
new file mode 100644
index 0000000000..d4d55ff674
--- /dev/null
+++ b/examples/utils.py
@@ -0,0 +1,449 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import shutil
+import string
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+import cv2
+import deeplabcut
+import deeplabcut.utils.auxiliaryfunctions as af
+import numpy as np
+import pandas as pd
+from deeplabcut.compat import Engine
+from deeplabcut.generate_training_dataset import get_existing_shuffle_indices
+from PIL import Image
+
+
+def log_step(message: Any) -> None:
+ print(100 * "-")
+ print(str(message))
+ print(100 * "-")
+
+
+def cleanup(test_path: Path) -> None:
+ if test_path.exists():
+ shutil.rmtree(test_path)
+
+
+@dataclass(frozen=True)
+class SyntheticProjectParameters:
+ multianimal: bool
+ num_bodyparts: int
+ num_frames: int = 10
+ num_individuals: int = 1
+ num_unique: int = 0
+ identity: bool = False
+ frame_shape: tuple[int, int] = (480, 640)
+
+ def bodyparts(self) -> list[str]:
+ return [i for i in string.ascii_lowercase[: self.num_bodyparts]]
+
+ def unique(self) -> list[str]:
+ return [f"unique_{i}" for i in string.ascii_lowercase[: self.num_unique]]
+
+ def individuals(self) -> list[str]:
+ return [f"animal_{i}" for i in range(self.num_individuals)]
+
+
+def sample_pose_random(
+ gen: np.random.Generator,
+ num_individuals: int,
+ num_bodyparts: int,
+ num_unique: int,
+ img_h: int,
+ img_w: int,
+) -> np.ndarray:
+ """Fully random pose sampling"""
+ xs = gen.choice(img_w, size=(num_individuals, num_bodyparts), replace=False)
+ ys = gen.choice(img_h, size=(num_individuals, num_bodyparts), replace=False)
+ pose = np.stack([xs, ys], axis=-1)
+
+ image_data = pose.reshape(-1)
+ if num_unique > 0:
+ unique_pose = np.stack(
+ [
+ gen.choice(img_w, size=(1, num_unique), replace=False),
+ gen.choice(img_h, size=(1, num_unique), replace=False),
+ ],
+ axis=-1,
+ )
+ image_data = np.concatenate([image_data, unique_pose.reshape(-1)])
+ return image_data
+
+
+def sample_pose_from_center(
+ center_xs: np.ndarray,
+ center_ys: np.ndarray,
+ num_individuals: int,
+ num_bodyparts: int,
+ num_unique: int,
+ radius: int = 25,
+) -> np.ndarray:
+ """Sample keypoints from the center of each individual"""
+ pose = np.zeros((num_individuals, num_bodyparts, 2))
+ for i, (xc, yc) in enumerate(zip(center_xs, center_ys)):
+ if i < num_individuals:
+ x_start, x_end = xc - radius + 1, xc + radius - 1
+ y_start, y_end = yc - radius + 1, yc + radius - 1
+ pose[i, :, 0] = np.linspace(start=x_start, stop=x_end, num=num_bodyparts)
+ pose[i, :, 1] = np.linspace(start=y_start, stop=y_end, num=num_bodyparts)
+
+ image_data = pose.reshape(-1)
+ if num_unique > 0:
+ xc, yc = center_xs[-1], center_ys[-1]
+ x_start, x_end = xc - radius + 1, xc + radius - 1
+ y_start, y_end = yc - radius + 1, yc + radius - 1
+ unique_pose = np.zeros((1, num_unique, 2))
+ unique_pose[0, :, 0] = np.linspace(start=x_start, stop=x_end, num=num_unique)
+ unique_pose[0, :, 1] = np.linspace(start=y_start, stop=y_end, num=num_unique)
+ image_data = np.concatenate([image_data, unique_pose.reshape(-1)])
+ return image_data
+
+
+def gen_fake_data(
+ scorer: str,
+ video_name: str,
+ params: SyntheticProjectParameters,
+) -> pd.DataFrame:
+ kpt_entries = ["x", "y"]
+ col_names = ["scorer", "individuals", "bodyparts", "coords"]
+ col_values = []
+ for i in params.individuals():
+ for b in params.bodyparts():
+ col_values += [(scorer, i, b, entry) for entry in kpt_entries]
+
+ for unique_bpt in params.unique():
+ col_values += [(scorer, "single", unique_bpt, entry) for entry in kpt_entries]
+
+ index_data = []
+ pose_data = []
+ gen = np.random.default_rng(seed=0)
+
+ # sample starting points for each individual
+ img_h, img_w = params.frame_shape[:2]
+ radius = 8
+ center_xs = gen.choice(
+ np.arange(radius, img_w - radius),
+ size=params.num_individuals + 1, # in case unique bodyparts
+ replace=False,
+ )
+ center_ys = gen.choice(
+ np.arange(radius, img_h - radius),
+ size=params.num_individuals + 1, # in case unique bodyparts
+ replace=False,
+ )
+
+ for frame_index in range(params.num_frames):
+ index_data.append(("labeled-data", video_name, f"img{frame_index:04}.png"))
+ pose_data.append(
+ sample_pose_from_center(
+ center_xs,
+ center_ys,
+ num_individuals=params.num_individuals,
+ num_bodyparts=params.num_bodyparts,
+ num_unique=params.num_unique,
+ radius=radius,
+ )
+ )
+ mvt_x = gen.integers(low=-1, high=4, size=center_xs.size)
+ mvt_y = gen.integers(low=-1, high=4, size=center_ys.size)
+ center_xs = np.clip(center_xs + mvt_x, radius, img_w - radius)
+ center_ys = np.clip(center_ys + mvt_y, radius, img_h - radius)
+
+ pose = np.stack(pose_data)
+ pose[params.num_frames // 2, :] = np.nan # add missing row in a frame
+ for idv in range(params.num_individuals):
+ idv_start = 2 * params.num_bodyparts * idv
+ idv_end = 2 * params.num_bodyparts * (idv + 1)
+ if params.num_frames > idv + 1:
+ pose[idv + 1, idv_start:idv_end] = np.nan
+
+ for bpt in range(params.num_bodyparts):
+ frame_idx = 1 + params.num_individuals + bpt
+ idv_idx = bpt % params.num_individuals
+ offset = 2 * params.num_bodyparts * idv_idx
+ bpt_start, bpt_end = 2 * bpt + offset, 2 * (bpt + 1) + offset
+ if params.num_frames + 1 > frame_idx:
+ pose[frame_idx, bpt_start:bpt_end] = np.nan
+
+ return pd.DataFrame(
+ pose,
+ index=pd.MultiIndex.from_tuples(index_data),
+ columns=pd.MultiIndex.from_tuples(col_values, names=col_names),
+ )
+
+
+def gen_fake_image(
+ project_root: Path,
+ row: pd.Series,
+ params: SyntheticProjectParameters,
+ radius: int = 5,
+):
+ img_h, img_w = params.frame_shape
+ image_array = np.zeros((*params.frame_shape, 3), dtype=np.uint8)
+ for i, idv in enumerate(params.individuals()):
+ r = int(255 * (i + 1) / params.num_individuals)
+ if "individuals" in row.index.names:
+ idv_data = row.droplevel("scorer").loc[idv]
+ else:
+ idv_data = row.droplevel("scorer")
+
+ keypoints = idv_data.to_numpy().reshape((-1, 2))
+ if not np.all(np.isnan(keypoints)):
+ idv_center = np.nanmean(keypoints, axis=0)
+ x, y = int(idv_center[0]), int(idv_center[1])
+ xmin, xmax = max(0, x - radius), min(img_w - 1, x + radius)
+ ymin, ymax = max(0, y - radius), min(img_h - 1, y + radius)
+ image_array[ymin:ymax, xmin:xmax, 0] = r
+
+ for j, bpt in enumerate(params.bodyparts()):
+ g = int(255 * (j + 1) / params.num_bodyparts)
+
+ bpt_data = idv_data.loc[bpt]
+ if np.all(~pd.isnull(bpt_data)):
+ x, y = int(bpt_data.x), int(bpt_data.y)
+ xmin, xmax = max(0, x - radius), min(img_w - 1, x + radius)
+ ymin, ymax = max(0, y - radius), min(img_h - 1, y + radius)
+ image_array[ymin:ymax, xmin:xmax, 0] = r
+ image_array[ymin:ymax, xmin:xmax, 1] = g
+
+ if params.num_unique > 0:
+ unique_data = row.droplevel("scorer").loc["single"]
+ for i, unique_bpt in enumerate(params.unique()):
+ bpt_data = unique_data.loc[unique_bpt]
+ if np.all(~pd.isnull(bpt_data)):
+ x, y = int(bpt_data.x), int(bpt_data.y)
+ xmin, xmax = max(0, x - radius), min(img_w - 1, x + radius)
+ ymin, ymax = max(0, y - radius), min(img_h - 1, y + radius)
+ image_array[ymin:ymax, xmin:xmax, 2] = int(
+ 255 * (i + 1) / params.num_unique
+ )
+
+ img = Image.fromarray(image_array)
+ img.save(project_root / Path(*row.name))
+
+
+def generate_video_from_images(image_dir: Path, output_video: Path) -> None:
+ images = [p for p in image_dir.iterdir() if p.is_file() and p.suffix == ".png"]
+ images = sorted(images, key=lambda f: f.stem)
+ if len(images) == 0:
+ return
+
+ height, width, channels = cv2.imread(str(images[0])).shape
+ fourcc = cv2.VideoWriter_fourcc(*"MJPG")
+ out = cv2.VideoWriter(str(output_video), fourcc, 10, (width, height))
+ for img_path in images:
+ img = cv2.imread(str(img_path))
+ out.write(img)
+ out.release()
+
+
+def create_fake_project(path: Path, params: SyntheticProjectParameters) -> None:
+ if path.exists():
+ raise ValueError(f"Cannot create a fake project at an existing path")
+
+ scorer = "synthetic"
+ video_name = "cat"
+ path.mkdir(parents=True, exist_ok=False)
+ config = {
+ "Task": "synthetic",
+ "scorer": scorer,
+ "date": "Nov11",
+ "multianimalproject": params.multianimal,
+ "identity": params.identity,
+ "project_path": str(path / "config.yaml"),
+ "TrainingFraction": [0.8],
+ "iteration": 0,
+ "default_net_type": "resnet_50",
+ "default_augmenter": "default",
+ "default_track_method": "ellipse",
+ "snapshotindex": "all",
+ "batch_size": 8,
+ "pcutoff": 0.6,
+ "video_sets": {
+ str(path / "videos" / video_name): {
+ "crop": (0, params.frame_shape[1], 0, params.frame_shape[0]),
+ },
+ },
+ "start": 0,
+ "stop": 1,
+ "numframes2pick": 10,
+ "dotsize": 4,
+ "alphavalue": 1.0,
+ "colormap": "rainbow",
+ }
+ if not params.multianimal:
+ config["bodyparts"] = params.bodyparts()
+ assert params.num_individuals == 1
+ assert params.num_unique == 0
+ else:
+ config["bodyparts"] = "MULTI!"
+ config["multianimalbodyparts"] = params.bodyparts()
+ config["uniquebodyparts"] = params.unique()
+ config["individuals"] = params.individuals()
+
+ af.write_config(str(path / "config.yaml"), config)
+ image_dir = path / "labeled-data" / video_name
+ image_dir.mkdir(parents=True, exist_ok=False)
+
+ df = gen_fake_data(
+ scorer=scorer,
+ video_name=video_name,
+ params=params,
+ )
+ print("SYNTHETIC DATA:")
+ print(df)
+ print("\n")
+ if not params.multianimal:
+ df.columns = df.columns.droplevel("individuals")
+
+ df.to_hdf(image_dir / f"CollectedData_{scorer}.h5", key="df_with_missing")
+ df.to_csv(image_dir / f"CollectedData_{scorer}.csv")
+
+ for idx in range(params.num_frames):
+ gen_fake_image(path, df.iloc[idx], params=params, radius=5)
+
+ output_video = path / "videos" / "video.mp4"
+ output_video.parent.mkdir(exist_ok=True)
+ generate_video_from_images(image_dir, output_video)
+
+
+def copy_project_for_test() -> Path:
+ data_path = Path.cwd() / "openfield-Pranav-2018-10-30"
+ test_path = Path.cwd() / "pytorch-testscript1234-openfield-Pranav-2018-10-30"
+ if not test_path.exists():
+ shutil.copytree(data_path, test_path)
+
+ project_config = af.read_config(str(test_path / "config.yaml"))
+ videos = list(project_config["video_sets"].keys())
+ video = videos[0]
+ crop = project_config["video_sets"][video]
+ project_config["video_sets"] = {str(test_path / "videos" / "m3v1mp4.mp4"): crop}
+ af.write_config(str(test_path / "config.yaml"), project_config)
+ return test_path
+
+
+def run(
+ config_path: Path,
+ train_fraction: float,
+ trainset_index: int,
+ net_type: str,
+ videos: list[str],
+ device: str,
+ engine: Engine = Engine.PYTORCH,
+ pytorch_cfg_updates: dict | None = None,
+ create_labeled_videos: bool = False,
+) -> None:
+ times = [time.time()]
+ log_step(f"Testing with net type {net_type}")
+ log_step("Creating the training dataset")
+ deeplabcut.create_training_dataset(
+ str(config_path), net_type=net_type, engine=engine
+ )
+ existing_shuffles = get_existing_shuffle_indices(
+ config_path, train_fraction=train_fraction, engine=engine
+ )
+ shuffle_index = existing_shuffles[-1]
+
+ log_step(
+ f"Starting training for train_frac {train_fraction}, shuffle {shuffle_index}"
+ )
+ deeplabcut.train_network(
+ config=str(config_path),
+ shuffle=shuffle_index,
+ trainingsetindex=trainset_index,
+ device=device,
+ pytorch_cfg_updates=pytorch_cfg_updates,
+ )
+ times.append(time.time())
+ log_step(f"Train time: {times[-1] - times[-2]} seconds")
+
+ log_step(
+ f"Starting evaluation for train_frac {train_fraction}, shuffle {shuffle_index}"
+ )
+ deeplabcut.evaluate_network(
+ config=str(config_path),
+ Shuffles=[shuffle_index],
+ trainingsetindex=trainset_index,
+ device=device,
+ plotting=True,
+ per_keypoint_evaluation=True,
+ )
+ times.append(time.time())
+ log_step(f"Evaluation time: {times[-1] - times[-2]} seconds")
+
+ if len(videos) > 0:
+ log_step(f"Analyzing videos for {train_fraction}, shuffle {shuffle_index}")
+ video_kwargs = dict(
+ videos=videos, shuffle=shuffle_index, trainingsetindex=trainset_index
+ )
+ deeplabcut.analyze_videos(
+ str(config_path), **video_kwargs, device=device, auto_track=False
+ )
+ times.append(time.time())
+ log_step(f"Video analysis time: {times[-1] - times[-2]} seconds")
+ log_step(f"Total test time: {times[-1] - times[0]} seconds")
+
+ cfg = af.read_config(config_path)
+ if cfg.get("multianimalproject"):
+ if create_labeled_videos:
+ deeplabcut.create_video_with_all_detections(
+ str(config_path), **video_kwargs
+ )
+
+ # relaxed tracking parameters
+ deeplabcut.convert_detections2tracklets(
+ str(config_path),
+ **video_kwargs,
+ inferencecfg=dict(
+ boundingboxslack=10,
+ iou_threshold=0.2,
+ max_age=5,
+ method="m1",
+ min_hits=1,
+ minimalnumberofconnections=2,
+ pafthreshold=0.1,
+ pcutoff=0.1,
+ topktoretain=3,
+ variant=0,
+ withid=False,
+ ),
+ )
+ deeplabcut.stitch_tracklets(str(config_path), **video_kwargs, min_length=3)
+
+ if create_labeled_videos:
+ log_step(f"Making labeled video, {train_fraction}, shuffle={shuffle_index}")
+ results = deeplabcut.create_labeled_video(
+ config=str(config_path),
+ videos=videos,
+ shuffle=shuffle_index,
+ trainingsetindex=trainset_index,
+ )
+ assert all(results), f"Failed to create some labeled video for {videos}"
+
+
+if __name__ == "__main__":
+ create_fake_project(
+ path=Path("synthetic-data-niels"),
+ params=SyntheticProjectParameters(
+ multianimal=True,
+ num_bodyparts=4,
+ num_individuals=3,
+ num_unique=1,
+ num_frames=50,
+ frame_shape=(128, 256),
+ ),
+ )
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000..e15d2a759c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,20 @@
+[tool.yapf]
+ based_on_style = "google"
+ indent_width = 4
+
+ [tool.isort]
+ multi_line_output = 3
+ include_trailing_comma = true
+ force_sort_within_sections = false
+ lexicographical = true
+ single_line_exclusions = ['typing']
+ order_by_type = false
+ group_by_package = true
+ line_length = 88
+ skip = [
+ "__init__.py",
+ ]
+[tool.pytest.ini_options]
+markers = [
+ "require_models: mark test as requiring models to run"
+]
\ No newline at end of file
diff --git a/reinstall.sh b/reinstall.sh
index 806f3aeac1..5ffda20f3c 100755
--- a/reinstall.sh
+++ b/reinstall.sh
@@ -1,3 +1,3 @@
pip uninstall deeplabcut
python3 setup.py sdist bdist_wheel
-pip install dist/deeplabcut-2.3.11-py3-none-any.whl
+pip install dist/deeplabcut-3.0.0rc6-py3-none-any.whl
diff --git a/requirements.txt b/requirements.txt
index 9606665251..402239d0ad 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,24 @@
+# novel for pytorch DLC:
+albumentations<=1.4.3
+einops
+pycocotools
+timm
+wandb
+
+# existing:
dlclibrary
ipython
filterpy
ruamel.yaml>=0.15.0
intel-openmp
imageio-ffmpeg
-imgaug==0.4.0
+imgaug>=0.4.0
numba>=0.54.0
-matplotlib<=3.5.2
+matplotlib>=3.3, <3.8.4
networkx>=2.6
numpy>=1.18.5,<2.0.0
pandas>=1.0.1,!=1.5.0
+Pillow>=7.1
pyyaml
scikit-image>=0.17
scikit-learn>=1.0
@@ -17,8 +26,8 @@ scipy>=1.9
statsmodels>=0.11
tensorflow>=2.0,<2.13.0
tables==3.8.0
-tensorpack==0.11
-tf_slim==1.1.0
-torch==1.12
+tensorpack>=0.11
+tf_slim>=1.1.0
+torch>=2.0.0
+torchvision
tqdm
-Pillow>=7.1
diff --git a/setup.py b/setup.py
index 3070f5bbac..b9ce7c31a2 100644
--- a/setup.py
+++ b/setup.py
@@ -1,23 +1,53 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
-DeepLabCut2.0-2.2 Toolbox (deeplabcut.org)
+DeepLabCut2.0-3.0 Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
https://github.com/DeepLabCut/DeepLabCut
Please see AUTHORS for contributors.
-https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""
+from __future__ import annotations
import setuptools
+from pathlib import Path
+
with open("README.md", encoding="utf-8", errors="replace") as fh:
long_description = fh.read()
+def super_animal_config_paths() -> list[str]:
+ config_dirs = [
+ Path("deeplabcut") / "modelzoo" / "model_configs",
+ Path("deeplabcut") / "modelzoo" / "project_configs",
+ ]
+
+ configs = []
+ for subdir in config_dirs:
+ for p in subdir.iterdir():
+ if p.suffix == ".yaml":
+ configs.append(str(p))
+
+ return configs
+
+
+def pytorch_config_paths() -> list[str]:
+ pytorch_configs = []
+ config_dir = Path("deeplabcut") / "pose_estimation_pytorch" / "config"
+ config_subdirs = [p for p in config_dir.iterdir() if p.is_dir()]
+ for subdir in config_subdirs:
+ for p in subdir.iterdir():
+ if p.suffix == ".yaml":
+ pytorch_configs.append(str(p))
+
+ return pytorch_configs
+
+
setuptools.setup(
name="deeplabcut",
- version="2.3.11",
+ version="3.0.0rc6",
author="A. & M.W. Mathis Labs",
author_email="alexander@deeplabcut.org",
description="Markerless pose-estimation of user-defined features with deep learning",
@@ -25,6 +55,9 @@
long_description_content_type="text/markdown",
url="https://github.com/DeepLabCut/DeepLabCut",
install_requires=[
+ "albumentations<=1.4.3",
+ "dlclibrary>=0.0.7",
+ "einops",
"dlclibrary>=0.0.6",
"filterpy>=1.4.4",
"ruamel.yaml>=0.15.0",
@@ -39,11 +72,14 @@
"scikit-learn>=1.0",
"scipy>=1.9",
"statsmodels>=0.11",
- "torch",
+ "tables==3.8.0",
+ "timm",
+ "torch>=2.0.0",
+ "torchvision",
"tqdm",
+ "pycocotools",
"pyyaml",
"Pillow>=7.1",
- "tables==3.8.0",
],
extras_require={
"gui": [
@@ -66,6 +102,7 @@
"tf_slim>=1.1.0",
],
"modelzoo": ["huggingface_hub"],
+ "wandb": ["wandb"],
},
scripts=["deeplabcut/pose_estimation_tensorflow/models/pretrained/download.sh"],
packages=setuptools.find_packages(),
@@ -76,12 +113,13 @@
"deeplabcut/pose_cfg.yaml",
"deeplabcut/inference_cfg.yaml",
"deeplabcut/reid_cfg.yaml",
+ "deeplabcut/modelzoo/models_to_framework.json",
"deeplabcut/pose_estimation_tensorflow/models/pretrained/pretrained_model_urls.yaml",
- "deeplabcut/pose_estimation_tensorflow/superanimal_configs/superquadruped.yaml",
- "deeplabcut/pose_estimation_tensorflow/superanimal_configs/supertopview.yaml",
"deeplabcut/gui/style.qss",
"deeplabcut/gui/media/logo.png",
"deeplabcut/gui/media/dlc_1-01.png",
+ "deeplabcut/gui/media/dlc-pt.png",
+ "deeplabcut/gui/media/dlc-tf.png",
"deeplabcut/gui/assets/logo.png",
"deeplabcut/gui/assets/logo_transparent.png",
"deeplabcut/gui/assets/welcome.png",
@@ -91,8 +129,7 @@
"deeplabcut/gui/assets/icons/new_project2.png",
"deeplabcut/gui/assets/icons/open.png",
"deeplabcut/gui/assets/icons/open2.png",
- "deeplabcut/modelzoo/models.json",
- ],
+ ] + super_animal_config_paths() + pytorch_config_paths(),
)
],
include_package_data=True,
@@ -102,7 +139,7 @@
"Operating System :: OS Independent",
],
entry_points="""[console_scripts]
- dlc=dlc:main""",
+ dlc=deeplabcut.__main__:main""",
)
# https://www.python.org/dev/peps/pep-0440/#compatible-release
diff --git a/tests/conftest.py b/tests/conftest.py
index 1e67d0ca9f..1b29154f98 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,7 +15,7 @@
import shutil
import urllib.request
import zipfile
-from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils
+from deeplabcut.core import inferenceutils
from io import BytesIO
from PIL import Image
from tqdm import tqdm
diff --git a/tests/core/inferenceutils/test_map_computation.py b/tests/core/inferenceutils/test_map_computation.py
new file mode 100644
index 0000000000..889fe3c361
--- /dev/null
+++ b/tests/core/inferenceutils/test_map_computation.py
@@ -0,0 +1,433 @@
+"""Tests mAP computation from inferenceutils"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from deeplabcut.core import inferenceutils
+from deeplabcut.pose_estimation_pytorch.data.utils import bbox_from_keypoints
+
+
+@pytest.mark.parametrize(
+ "ground_truth",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 2],
+ [150.0, 15.0, 2],
+ [202.0, 20.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ ],
+ },
+ ],
+)
+@pytest.mark.parametrize(
+ "predictions",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 0.9],
+ [150.0, 15.0, 0.7],
+ [202.0, 20.0, 0.8],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 0.9],
+ [140.0, 17.0, 0.7],
+ [192.0, 22.0, 0.8],
+ ],
+ [
+ [97.0, 11.0, 0.5],
+ [148.0, 14.0, 0.2],
+ [202.0, 21.0, 0.3],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 0.9],
+ [np.nan, np.nan, 0.0],
+ [192.0, 22.0, 0.8],
+ ],
+ [
+ [97.0, 11.0, 0.5],
+ [148.0, 14.0, 0.2],
+ [202.0, 21.0, 0.3],
+ ],
+ ],
+ },
+ ],
+)
+def test_map_single_image_simple(ground_truth: dict, predictions: dict):
+ gt = {k: np.array(v) for k, v in ground_truth.items()}
+ pred = {k: np.array(v) for k, v in predictions.items()}
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize(
+ "ground_truth",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 2],
+ [150.0, 15.0, 2],
+ [202.0, 20.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [326.0, 236.0, 2],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [0.0, 0.0, 0],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [0, 0, 0],
+ [457.0, 832.0, 2],
+ ],
+ [
+ [452.0, 321.0, 2],
+ [213.0, 387.0, 2],
+ [213.0, 832.0, 2],
+ ],
+ [
+ [253.0, 238.0, 2],
+ [213.0, 238.0, 2],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ ],
+)
+def test_map_single_image_random_errors(ground_truth: dict):
+ rng = np.random.default_rng(seed=0)
+
+ gt = {k: np.array(v) for k, v in ground_truth.items()}
+ pred = {}
+ for k, gt_kpts in gt.items():
+ num_idv, num_bpt = gt_kpts.shape[:2]
+
+ error = rng.integers(low=-30, high=30, size=(num_idv, num_bpt, 2))
+ scores = rng.random(size=(num_idv, num_bpt))
+
+ pred[k] = np.zeros(shape=(num_idv, num_bpt, 3))
+ pred[k][..., :2] = np.clip(gt_kpts[..., :2] + error, 0, 1024)
+ pred[k][..., 2] = scores
+
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize("num_images", [1, 2, 5, 10])
+@pytest.mark.parametrize("num_joints", [2, 5, 8, 20])
+@pytest.mark.parametrize("max_error", [1, 2, 5, 20, 40])
+def test_random_map_computation(num_images, num_joints, max_error):
+ rng = np.random.default_rng(seed=0)
+
+ num_individuals = rng.integers(low=0, high=20, size=(num_images, 2))
+ max_idv = num_individuals.max(initial=0)
+
+ gt = {}
+ pred = {}
+ for i, (gt_idv, pred_idv) in enumerate(num_individuals):
+ # padding needed as we then stack
+ gt_kpts = np.zeros((max_idv, num_joints, 3))
+ pred_kpts = -np.ones((max_idv, num_joints, 3))
+
+ gt_kpts[:gt_idv] = 2 * np.ones((gt_idv, num_joints, 3))
+ gt_kpts[:gt_idv, :, :2] = rng.integers(
+ low=0, high=1024, size=(gt_idv, num_joints, 2)
+ )
+ gt[f"img_{i}"] = gt_kpts
+
+ # set scores
+ pred_kpts[:pred_idv, :, 2] = rng.random(size=(pred_idv, num_joints))
+
+ # predictions that are ground truth + error
+ matched = min(gt_idv, pred_idv)
+ if matched > 0:
+ error = rng.integers(
+ low=-max_error, high=max_error, size=(matched, num_joints, 2)
+ )
+ matched_pred = gt_kpts[:matched, :, :2] + error
+ pred_kpts[:matched, :, :2] = np.clip(matched_pred, 0, 1024)
+
+ # random predictions
+ unmatched = pred_idv - matched
+ if unmatched > 0:
+ pred_kpts[matched:pred_idv, :, :2] = rng.integers(
+ low=0, high=1024, size=(unmatched, num_joints, 2)
+ )
+
+ pred[f"img_{i}"] = pred_kpts
+
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize("num_images", [1, 2, 5, 10])
+@pytest.mark.parametrize("num_joints", [2, 5, 8, 20])
+@pytest.mark.parametrize("max_error", [1, 2, 5, 20, 40])
+def test_random_map_computation_with_missing_kpts(num_images, num_joints, max_error):
+ rng = np.random.default_rng(seed=0)
+
+ num_individuals = rng.integers(low=0, high=20, size=(num_images, 2))
+ max_idv = num_individuals.max(initial=0)
+
+ gt = {}
+ pred = {}
+ for i, (gt_idv, pred_idv) in enumerate(num_individuals):
+ # padding needed as we then stack
+ gt_kpts = np.zeros((max_idv, num_joints, 3))
+ pred_kpts = -np.ones((max_idv, num_joints, 3))
+
+ gt_kpts[:gt_idv] = 2 * np.ones((gt_idv, num_joints, 3))
+ gt_kpts[:gt_idv, :, :2] = rng.integers(
+ low=0, high=1024, size=(gt_idv, num_joints, 2)
+ )
+ gt[f"img_{i}"] = gt_kpts
+
+ # drop some ground truth keypoints
+ gt_vis_mask = rng.random(size=(max_idv, num_joints)) < 0.2
+ gt_kpts[gt_vis_mask, 2] = 0
+
+ # set scores
+ pred_kpts[:pred_idv, :, 2] = rng.random(size=(pred_idv, num_joints))
+
+ # predictions that are ground truth + error
+ matched = min(gt_idv, pred_idv)
+ if matched > 0:
+ error = rng.integers(
+ low=-max_error, high=max_error, size=(matched, num_joints, 2)
+ )
+ matched_pred = gt_kpts[:matched, :, :2] + error
+ pred_kpts[:matched, :, :2] = np.clip(matched_pred, 0, 1024)
+
+ # random predictions
+ unmatched = pred_idv - matched
+ if unmatched > 0:
+ pred_kpts[matched:pred_idv, :, :2] = rng.integers(
+ low=0, high=1024, size=(unmatched, num_joints, 2)
+ )
+
+ pred[f"img_{i}"] = pred_kpts
+
+ _evaluate(gt, pred)
+
+
+def _evaluate(gt: dict[str, np.ndarray], pred: dict[str, np.ndarray]):
+ for k, v in gt.items():
+ print(20 * "-")
+ print(k)
+ print("GT")
+ print(v)
+ print("PR")
+ print(pred[k])
+
+ gt_assemblies = _to_assemblies(gt, ground_truth=True)
+ pred_assemblies = _to_assemblies(pred, ground_truth=False)
+ oks = inferenceutils.evaluate_assembly_greedy(
+ assemblies_gt=gt_assemblies,
+ assemblies_pred=pred_assemblies,
+ oks_sigma=0.1,
+ oks_thresholds=np.linspace(0.5, 0.95, 10),
+ margin=0.0,
+ symmetric_kpts=None,
+ )
+
+ num_joints = gt[list(gt.keys())[0]].shape[1]
+ coco_gt = _to_coco_ground_truth(gt, num_joints, bbox_margin=0)
+ coco_pred = _to_coco_predictions(coco_gt, pred, bbox_margin=0)
+ coco_oks = eval_coco(coco_gt, coco_pred, num_joints)
+ print(20 * "-")
+ print(f"dlc mAP:")
+ for k, v in oks.items():
+ print(k)
+ print(v)
+ print()
+ print(20 * "-")
+ print(f"pycocotools mAP: {coco_oks}")
+ print()
+ assert oks["mAP"] == coco_oks
+
+
+def _to_assemblies(
+ data: dict[str, np.ndarray], ground_truth: bool,
+) -> dict[str, list[inferenceutils.Assembly]]:
+ images = list(data.keys())
+ raw_data = np.stack([data[i] for i in images], axis=0)
+
+ # mask not visible entries
+ mask = raw_data[..., 2] <= 0
+ raw_data[mask] = np.nan
+
+ # set the "score" to 1 for ground truth
+ if ground_truth:
+ raw_data[~mask, 2] = 1
+
+ return {
+ images[i]: assembly
+ for i, assembly in inferenceutils._parse_ground_truth_data(raw_data).items()
+ }
+
+
+def _to_coco_ground_truth(
+ data: dict[str, np.ndarray],
+ num_joints: int,
+ bbox_margin: int = 0,
+ image_size: tuple[int, int] = (1024, 1024),
+) -> dict[str, list[dict]]:
+ w, h = image_size
+ anns, images = [], []
+ for path, image_keypoints in data.items():
+ id_ = len(images) + 1
+ images.append(dict(id=id_, file_name=path, width=w, height=h))
+
+ assert image_keypoints.shape[1] == num_joints
+ for idv_id, kpts in enumerate(image_keypoints):
+ visible = kpts[:, 2] > 0
+ num_keypoints = visible.sum()
+
+ if num_keypoints > 1:
+ bbox = bbox_from_keypoints(
+ keypoints=kpts,
+ image_h=h,
+ image_w=w,
+ margin=bbox_margin,
+ )
+ area = bbox[2].item() * bbox[3].item()
+ anns.append(
+ {
+ "id": len(anns) + 1,
+ "image_id": id_,
+ "category_id": 1,
+ "area": area,
+ "bbox": bbox.tolist(),
+ "keypoints": kpts.reshape(-1).tolist(),
+ "iscrowd": 0,
+ "num_keypoints": num_keypoints,
+ }
+ )
+
+ keypoints = [f"bpt{i}" for i in range(num_joints)]
+ category = dict(id=1, name="animal", supercategory="animal", keypoints=keypoints)
+ return {"annotations": anns, "categories": [category], "images": images}
+
+
+def _to_coco_predictions(
+ ground_truth: dict,
+ predictions: dict[str, np.ndarray],
+ bbox_margin: int = 0,
+ image_size: tuple[int, int] = (1024, 1024),
+) -> list[dict]:
+ w, h = image_size
+ num_joints = len(ground_truth["categories"][0]["keypoints"])
+ path_to_id = {img["file_name"]: img["id"] for img in ground_truth["images"]}
+
+ coco_predictions = []
+ for path, image_keypoints in predictions.items():
+ assert image_keypoints.shape[1] == num_joints
+
+ img_id = path_to_id[path]
+ valid_predictions = [
+ kpt for kpt in image_keypoints if np.any(np.all(~np.isnan(kpt), axis=-1))
+ ]
+ for kpts in valid_predictions:
+ score = float(np.nanmean(kpts[:, 2]).item())
+ kpts = kpts.copy()
+ kpts[:, 2] = 2
+
+ # NaN predictions to infinity
+ kpts[np.isnan(kpts)] = np.inf
+
+ bbox = bbox_from_keypoints(
+ keypoints=kpts,
+ image_h=h,
+ image_w=w,
+ margin=bbox_margin,
+ )
+ area = bbox[2].item() * bbox[3].item()
+ coco_predictions.append(
+ {
+ "image_id": img_id,
+ "category_id": 1,
+ "keypoints": kpts.reshape(-1).tolist(),
+ "bbox": bbox.tolist(),
+ "area": area,
+ "score": score,
+ }
+ )
+
+ return coco_predictions
+
+
+def eval_coco(
+ ground_truth: dict,
+ predictions: list[dict],
+ num_joints: int,
+) -> float | None:
+ try:
+ from pycocotools.coco import COCO
+ from pycocotools.cocoeval import COCOeval
+
+ coco = COCO()
+ coco.dataset["annotations"] = ground_truth["annotations"]
+ coco.dataset["categories"] = ground_truth["categories"]
+ coco.dataset["images"] = ground_truth["images"]
+ coco.createIndex()
+
+ coco_det = coco.loadRes(predictions)
+ coco_eval = COCOeval(coco, coco_det, iouType="keypoints")
+ coco_eval.params.kpt_oks_sigmas = np.array(num_joints * [0.1])
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ return float(coco_eval.stats[0])
+
+ except ModuleNotFoundError as err:
+ print(f"pycocotools is not installed")
diff --git a/tests/core/metrics/test_meitrcs_identity_accuracy.py b/tests/core/metrics/test_meitrcs_identity_accuracy.py
new file mode 100644
index 0000000000..1bde7edb88
--- /dev/null
+++ b/tests/core/metrics/test_meitrcs_identity_accuracy.py
@@ -0,0 +1,218 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests for the scoring methods"""
+import numpy as np
+import pytest
+
+import deeplabcut.core.metrics.identity
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "individuals": ["i1", "i2"],
+ "bodyparts": ["arm"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[2.0, 2.0, 0.8]],
+ [[1.0, 1.0, 0.7]], # x, y, score
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.8, 0.5]],
+ [[0.51, 0.49]],
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[1.0, 1.0, 2]],
+ [[0, 0, 0]], # x, y, visibility
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 1.0,
+ },
+ },
+ {
+ "individuals": ["i1", "i2"],
+ "bodyparts": ["arm"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7]],
+ [[2.0, 2.0, 0.7]], # x, y, score
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.4, 0.6]],
+ [[0.6, 0.4]],
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2]],
+ [[1.0, 1.0, 2]], # x, y, visibility
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 1.0,
+ },
+ },
+ {
+ "individuals": ["i1", "i2"],
+ "bodyparts": ["arm"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7]],
+ [[2.0, 2.0, 0.7]], # x, y, score
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.6, 0.4]],
+ [[0.6, 0.4]], # both assemblies assigned to idv 1
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2]],
+ [[1.0, 1.0, 2]], # x, y, visibility
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 0.5,
+ },
+ },
+ {
+ "individuals": ["i1", "i2"],
+ "bodyparts": ["arm"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7]],
+ [[2.0, 2.0, 0.7]], # x, y, score
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.6, 0.4]],
+ [[0.4, 0.6]], # both assigned to wrong ID
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2]], # x, y, visibility
+ [[1.0, 1.0, 2]],
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 0.0,
+ },
+ },
+ {
+ "individuals": ["i1", "i2"],
+ "bodyparts": ["arm", "leg"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7], [10.0, 10.0, 0.9]],
+ [[100.0, 100.0, 0.9], [90.0, 90.9, 0.8]],
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.7, 0.3], [0.6, 0.2]],
+ [[0.6, 0.3], [0.6, 0.2]], # should not matter, not assigned to GT
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2], [8.0, 8.0, 2]], # x, y, visibility
+ [[-1, -1, 0.0], [-1, -1, 0.0]], # not visible
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 1.0,
+ "leg_accuracy": 1.0,
+ },
+ },
+ {
+ "individuals": ["i1", "i2", "i3"],
+ "bodyparts": ["arm", "leg"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7], [10.0, 10.0, 0.9]],
+ [[100.0, 100.0, 0.9], [90.0, 90.9, 0.8]],
+ [[110.0, 110.0, 0.9], [98.0, 91.9, 0.8]],
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.7, 0.3], [0.6, 0.2]], # assigned to correct ID
+ [[0.6, 0.3], [0.6, 0.2]], # should not matter, not assigned to GT
+ [[0.6, 0.3], [0.6, 0.2]], # should not matter, not assigned to GT
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2], [8.0, 8.0, 2]], # x, y, visibility
+ [[-1, -1, 0.0], [-1, -1, 0.0]], # not visible
+ [[-1, -1, 0.0], [-1, -1, 0.0]], # not visible
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 1.0,
+ "leg_accuracy": 1.0,
+ },
+ },
+ {
+ "individuals": ["i1", "i2", "i3"],
+ "bodyparts": ["arm", "leg"],
+ "predictions": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, 3)
+ [[1.0, 1.0, 0.7], [10.0, 10.0, 0.9]],
+ [[100.0, 100.0, 0.9], [90.0, 90.9, 0.8]],
+ [[110.0, 110.0, 0.9], [98.0, 91.9, 0.8]],
+ ],
+ },
+ "identity_scores": {
+ "img0.png": [ # (num_assemblies, num_bodyparts, num_individuals)
+ [[0.7, 0.3, 0.1], [0.6, 0.2, 0.1]], # assigned to correct ID
+ [[0.1, 0.2, 0.7], [0.4, 0.3, 0.2]], # 1st correct, 2nd wrong
+ [
+ [0.6, 0.3, 0.5],
+ [0.6, 0.2, 0.4],
+ ], # should not matter, not assigned to GT
+ ],
+ },
+ "ground_truth": {
+ "img0.png": [ # (num_individuals, num_bodyparts, 3)
+ [[2.0, 2.0, 2], [8.0, 8.0, 2]], # x, y, visibility
+ [[-1, -1, 0.0], [-1, -1, 0.0]], # not visible
+ [[90.0, 90, 2], [80, 80, 2.0]], # x, y, visibility
+ ]
+ },
+ "accuracy": {
+ "arm_accuracy": 1.0,
+ "leg_accuracy": 0.5,
+ },
+ },
+ ],
+)
+def test_id_accuracy(data) -> None:
+ scores = deeplabcut.core.metrics.identity.compute_identity_scores(
+ individuals=data["individuals"],
+ bodyparts=data["bodyparts"],
+ predictions={k: np.array(v) for k, v in data["predictions"].items()},
+ identity_scores={k: np.array(v) for k, v in data["identity_scores"].items()},
+ ground_truth={k: np.array(v) for k, v in data["ground_truth"].items()},
+ )
+ assert scores == data["accuracy"]
diff --git a/tests/core/metrics/test_metrics_api.py b/tests/core/metrics/test_metrics_api.py
new file mode 100644
index 0000000000..5051794a13
--- /dev/null
+++ b/tests/core/metrics/test_metrics_api.py
@@ -0,0 +1,112 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""General tests for the metrics API"""
+import numpy as np
+import pytest
+from numpy.testing import assert_almost_equal
+
+import deeplabcut.core.metrics as metrics
+
+
+def _get_gt_and_pred_with_constant_err(
+ num_idv: int, num_bpt: int, error: float
+) -> tuple[np.ndarray, np.ndarray]:
+ gt = np.arange(num_idv * num_bpt * 3).astype(float).reshape((num_idv, num_bpt, 3))
+ gt[..., 2] = 2
+ predictions = gt.copy()
+ predictions[..., 2] = 0.9
+ predictions[..., :2] += error
+ return gt, predictions
+
+
+def test_computing_metrics_with_no_predictions():
+ gt = np.arange(5 * 6 * 3).astype(float).reshape((5, 6, 3))
+ gt[..., 2] = 2
+ metrics.compute_metrics(
+ ground_truth={"image": gt},
+ predictions={"image": np.zeros((0, 12, 3))},
+ unique_bodypart_gt=None,
+ unique_bodypart_poses=None,
+ )
+
+
+@pytest.mark.parametrize("error", [0.5, 1, 2])
+def test_computing_metrics_with_constant_error(error):
+ # only works for small errors: otherwise another matching can be found
+ gt, predictions = _get_gt_and_pred_with_constant_err(5, 6, error)
+ results = metrics.compute_metrics(
+ ground_truth={"image": gt},
+ predictions={"image": predictions},
+ unique_bodypart_gt=None,
+ unique_bodypart_poses=None,
+ )
+ assert_almost_equal(results["rmse"], np.sqrt(2) * error)
+ assert_almost_equal(results["rmse_pcutoff"], np.sqrt(2) * error)
+
+
+@pytest.mark.parametrize("error", [0.5, 1, 2])
+def test_metrics_with_unique_with_constant_error(error):
+ # only works for small errors: otherwise another matching can be found
+ gt, predictions = _get_gt_and_pred_with_constant_err(5, 6, error)
+ gt_unique, pred_unique = _get_gt_and_pred_with_constant_err(1, 8, error)
+ results = metrics.compute_metrics(
+ ground_truth={"image": gt},
+ predictions={"image": predictions},
+ unique_bodypart_gt={"image": gt_unique},
+ unique_bodypart_poses={"image": pred_unique},
+ )
+ assert_almost_equal(results["rmse"], np.sqrt(2) * error)
+ assert_almost_equal(results["rmse_pcutoff"], np.sqrt(2) * error)
+
+
+@pytest.mark.parametrize("error", [0.5, 1, 2])
+def test_metrics_per_bpt_with_unique_with_constant_error(error):
+ # only works for small errors: otherwise another matching can be found
+ gt, predictions = _get_gt_and_pred_with_constant_err(5, 6, error)
+ gt_unique, pred_unique = _get_gt_and_pred_with_constant_err(1, 8, error)
+ results = metrics.compute_metrics(
+ ground_truth={"image": gt},
+ predictions={"image": predictions},
+ unique_bodypart_gt={"image": gt_unique},
+ unique_bodypart_poses={"image": pred_unique},
+ per_keypoint_rmse=True,
+ )
+ assert_almost_equal(results["rmse"], np.sqrt(2) * error)
+ assert_almost_equal(results["rmse_pcutoff"], np.sqrt(2) * error)
+
+ for bpt_idx in range(gt.shape[1]):
+ key = f"rmse_keypoint_{bpt_idx}"
+ assert key in results
+ assert_almost_equal(results[key], np.sqrt(2) * error)
+ for bpt_idx in range(gt_unique.shape[1]):
+ key = f"rmse_unique_keypoint_{bpt_idx}"
+ assert key in results
+ assert_almost_equal(results[key], np.sqrt(2) * error)
+
+
+@pytest.mark.parametrize("error", [0.5, 1, 2])
+def test_computing_metrics_single_animal(error):
+ # only works for small errors: otherwise another matching can be found
+ gt = np.arange(6 * 3).astype(float).reshape((1, 6, 3))
+ gt[..., 2] = 2
+ predictions = gt.copy()
+ predictions[..., 2] = 0.9
+ predictions[..., :2] += error
+ results = metrics.compute_metrics(
+ ground_truth={"image": gt},
+ predictions={"image": predictions},
+ single_animal=True,
+ unique_bodypart_gt=None,
+ unique_bodypart_poses=None,
+ )
+ assert_almost_equal(results["rmse"], np.sqrt(2) * error)
+ assert_almost_equal(results["rmse_pcutoff"], np.sqrt(2) * error)
+
diff --git a/tests/core/metrics/test_metrics_map_computation.py b/tests/core/metrics/test_metrics_map_computation.py
new file mode 100644
index 0000000000..18c3296607
--- /dev/null
+++ b/tests/core/metrics/test_metrics_map_computation.py
@@ -0,0 +1,403 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests that mAP computation is correct"""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+from numpy.testing import assert_almost_equal
+
+from deeplabcut.core.metrics.api import prepare_evaluation_data
+from deeplabcut.core.metrics.distance_metrics import compute_oks
+from deeplabcut.pose_estimation_pytorch.data.utils import bbox_from_keypoints
+
+
+@pytest.mark.parametrize(
+ "ground_truth",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 2],
+ [150.0, 15.0, 2],
+ [202.0, 20.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ ],
+ },
+ ],
+)
+@pytest.mark.parametrize(
+ "predictions",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 0.9],
+ [150.0, 15.0, 0.7],
+ [202.0, 20.0, 0.8],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 0.9],
+ [140.0, 17.0, 0.7],
+ [192.0, 22.0, 0.8],
+ ],
+ [
+ [97.0, 11.0, 0.5],
+ [148.0, 14.0, 0.2],
+ [202.0, 21.0, 0.3],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 0.9],
+ [np.nan, np.nan, 0.0],
+ [192.0, 22.0, 0.8],
+ ],
+ [
+ [97.0, 11.0, 0.5],
+ [148.0, 14.0, 0.2],
+ [202.0, 21.0, 0.3],
+ ],
+ ],
+ },
+ ],
+)
+def test_map_single_image_simple(ground_truth: dict, predictions: dict):
+ gt = {k: np.array(v) for k, v in ground_truth.items()}
+ pred = {k: np.array(v) for k, v in predictions.items()}
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize(
+ "ground_truth",
+ [
+ {
+ "img0": [
+ [
+ [100.0, 10.0, 2],
+ [150.0, 15.0, 2],
+ [202.0, 20.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [326.0, 236.0, 2],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [0.0, 0.0, 0],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ {
+ "img0": [
+ [
+ [90.0, 12.0, 2],
+ [140.0, 17.0, 2],
+ [192.0, 22.0, 2],
+ ],
+ [
+ [726.0, 325.0, 2],
+ [0, 0, 0],
+ [457.0, 832.0, 2],
+ ],
+ [
+ [452.0, 321.0, 2],
+ [213.0, 387.0, 2],
+ [213.0, 832.0, 2],
+ ],
+ [
+ [253.0, 238.0, 2],
+ [213.0, 238.0, 2],
+ [457.0, 832.0, 2],
+ ],
+ ],
+ },
+ ],
+)
+def test_map_single_image_random_errors(ground_truth: dict):
+ rng = np.random.default_rng(seed=0)
+
+ gt = {k: np.array(v) for k, v in ground_truth.items()}
+ pred = {}
+ for k, gt_kpts in gt.items():
+ num_idv, num_bpt = gt_kpts.shape[:2]
+
+ error = rng.integers(low=-30, high=30, size=(num_idv, num_bpt, 2))
+ scores = rng.random(size=(num_idv, num_bpt))
+
+ pred[k] = np.zeros(shape=(num_idv, num_bpt, 3))
+ pred[k][..., :2] = np.clip(gt_kpts[..., :2] + error, 0, 1024)
+ pred[k][..., 2] = scores
+
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize("num_images", [1, 2, 5, 10])
+@pytest.mark.parametrize("num_joints", [2, 5, 8, 20])
+@pytest.mark.parametrize("max_error", [1, 2, 5, 20, 40])
+def test_random_map_computation(num_images, num_joints, max_error):
+ rng = np.random.default_rng(seed=0)
+
+ num_individuals = rng.integers(low=0, high=20, size=(num_images, 2))
+
+ gt, pred = {}, {}
+ for i, (gt_idv, pred_idv) in enumerate(num_individuals):
+ gt_kpts = 2 * np.ones((gt_idv, num_joints, 3))
+ gt_kpts[..., :2] = rng.integers(low=0, high=1024, size=(gt_idv, num_joints, 2))
+ gt[f"img_{i}"] = gt_kpts
+
+ # create predictions array
+ pred_kpts = np.zeros((pred_idv, num_joints, 3))
+ # set scores
+ pred_kpts[..., 2] = rng.random(size=(pred_idv, num_joints))
+
+ # predictions that are ground truth + error
+ matched = min(gt_idv, pred_idv)
+ if matched > 0:
+ error = rng.integers(
+ low=-max_error, high=max_error, size=(matched, num_joints, 2)
+ )
+ matched_pred = gt_kpts[:matched, :, :2] + error
+ pred_kpts[:matched, :, :2] = np.clip(matched_pred, 0, 1024)
+
+ # random predictions
+ unmatched = pred_idv - matched
+ if unmatched > 0:
+ pred_kpts[matched:, :, :2] = rng.integers(
+ low=0, high=1024, size=(unmatched, num_joints, 2)
+ )
+
+ pred[f"img_{i}"] = pred_kpts
+
+ _evaluate(gt, pred)
+
+
+@pytest.mark.parametrize("num_images", [1, 2, 5, 10])
+@pytest.mark.parametrize("num_joints", [2, 5, 8, 20])
+@pytest.mark.parametrize("max_error", [1, 2, 5, 20, 40])
+def test_random_map_computation_with_missing_kpts(num_images, num_joints, max_error):
+ rng = np.random.default_rng(seed=0)
+ num_individuals = rng.integers(low=0, high=20, size=(num_images, 2))
+
+ gt, pred = {}, {}
+ for i, (gt_idv, pred_idv) in enumerate(num_individuals):
+ gt_kpts = 2 * np.ones((gt_idv, num_joints, 3))
+ gt_kpts[..., :2] = rng.integers(low=0, high=1024, size=(gt_idv, num_joints, 2))
+ gt[f"img_{i}"] = gt_kpts
+
+ # drop some ground truth keypoints
+ gt_vis_mask = rng.random(size=(gt_idv, num_joints)) < 0.2
+ gt_kpts[gt_vis_mask, 2] = 0
+
+ # generate predicted keypoints
+ pred_kpts = np.zeros((pred_idv, num_joints, 3))
+ pred_kpts[:pred_idv, :, 2] = rng.random(size=(pred_idv, num_joints))
+
+ # predictions that are ground truth + error
+ matched = min(gt_idv, pred_idv)
+ if matched > 0:
+ error = rng.integers(
+ low=-max_error, high=max_error, size=(matched, num_joints, 2)
+ )
+ matched_pred = gt_kpts[:matched, :, :2] + error
+ pred_kpts[:matched, :, :2] = np.clip(matched_pred, 0, 1024)
+
+ # random predictions
+ unmatched = pred_idv - matched
+ if unmatched > 0:
+ pred_kpts[matched:, :, :2] = rng.integers(
+ low=0, high=1024, size=(unmatched, num_joints, 2)
+ )
+
+ pred[f"img_{i}"] = pred_kpts
+
+ _evaluate(gt, pred)
+
+
+def _evaluate(gt: dict[str, np.ndarray], pred: dict[str, np.ndarray]):
+ for k, v in gt.items():
+ print(20 * "-")
+ print(k)
+ print("GT")
+ print(v)
+ print("PR")
+ print(pred[k])
+
+ data = prepare_evaluation_data(gt, pred)
+ oks = compute_oks(data, oks_bbox_margin=0)
+
+ num_joints = gt[list(gt.keys())[0]].shape[1]
+ coco_gt = _to_coco_ground_truth(gt, num_joints, bbox_margin=0)
+ coco_pred = _to_coco_predictions(coco_gt, pred, bbox_margin=0)
+ coco_oks = eval_coco(coco_gt, coco_pred, num_joints)
+ print(20 * "-")
+ print(f"dlc mAP:")
+ for k, v in oks.items():
+ print(k)
+ print(v)
+ print(20 * "-")
+ print(f"pycocotools mAP: {coco_oks}")
+ print()
+ dlc_map = oks["mAP"] / 100
+ assert_almost_equal(dlc_map, coco_oks)
+
+
+def _to_coco_ground_truth(
+ data: dict[str, np.ndarray],
+ num_joints: int,
+ bbox_margin: int = 0,
+ image_size: tuple[int, int] = (1024, 1024),
+) -> dict[str, list[dict]]:
+ w, h = image_size
+ anns, images = [], []
+ for path, image_keypoints in data.items():
+ id_ = len(images) + 1
+ images.append(dict(id=id_, file_name=path, width=w, height=h))
+
+ assert image_keypoints.shape[1] == num_joints
+ for idv_id, kpts in enumerate(image_keypoints):
+ visible = kpts[:, 2] > 0
+ num_keypoints = visible.sum()
+
+ if num_keypoints > 1:
+ bbox = bbox_from_keypoints(
+ keypoints=kpts,
+ image_h=h,
+ image_w=w,
+ margin=bbox_margin,
+ )
+ area = bbox[2].item() * bbox[3].item()
+ anns.append(
+ {
+ "id": len(anns) + 1,
+ "image_id": id_,
+ "category_id": 1,
+ "area": area,
+ "bbox": bbox.tolist(),
+ "keypoints": kpts.reshape(-1).tolist(),
+ "iscrowd": 0,
+ "num_keypoints": num_keypoints,
+ }
+ )
+
+ keypoints = [f"bpt{i}" for i in range(num_joints)]
+ category = dict(id=1, name="animal", supercategory="animal", keypoints=keypoints)
+ return {"annotations": anns, "categories": [category], "images": images}
+
+
+def _to_coco_predictions(
+ ground_truth: dict,
+ predictions: dict[str, np.ndarray],
+ bbox_margin: int = 0,
+ image_size: tuple[int, int] = (1024, 1024),
+) -> list[dict]:
+ w, h = image_size
+ num_joints = len(ground_truth["categories"][0]["keypoints"])
+ path_to_id = {img["file_name"]: img["id"] for img in ground_truth["images"]}
+
+ coco_predictions = []
+ for path, image_keypoints in predictions.items():
+ assert image_keypoints.shape[1] == num_joints
+
+ img_id = path_to_id[path]
+ valid_predictions = [
+ kpt for kpt in image_keypoints if np.any(np.all(~np.isnan(kpt), axis=-1))
+ ]
+ for kpts in valid_predictions:
+ score = float(np.nanmean(kpts[:, 2]).item())
+ kpts = kpts.copy()
+ kpts[:, 2] = 2
+
+ # NaN predictions to infinity
+ kpts[np.isnan(kpts)] = np.inf
+
+ bbox = bbox_from_keypoints(
+ keypoints=kpts,
+ image_h=h,
+ image_w=w,
+ margin=bbox_margin,
+ )
+ area = bbox[2].item() * bbox[3].item()
+ coco_predictions.append(
+ {
+ "image_id": img_id,
+ "category_id": 1,
+ "keypoints": kpts.reshape(-1).tolist(),
+ "bbox": bbox.tolist(),
+ "area": area,
+ "score": score,
+ }
+ )
+
+ return coco_predictions
+
+
+def eval_coco(
+ ground_truth: dict,
+ predictions: list[dict],
+ num_joints: int,
+) -> float | None:
+ try:
+ from pycocotools.coco import COCO
+ from pycocotools.cocoeval import COCOeval
+
+ coco = COCO()
+ coco.dataset["annotations"] = ground_truth["annotations"]
+ coco.dataset["categories"] = ground_truth["categories"]
+ coco.dataset["images"] = ground_truth["images"]
+ coco.createIndex()
+
+ coco_det = coco.loadRes(predictions)
+ coco_eval = COCOeval(coco, coco_det, iouType="keypoints")
+ coco_eval.params.kpt_oks_sigmas = np.array(num_joints * [0.1])
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ return float(coco_eval.stats[0])
+
+ except ModuleNotFoundError as err:
+ print(f"pycocotools is not installed")
diff --git a/tests/core/metrics/test_metrics_rmse_computation.py b/tests/core/metrics/test_metrics_rmse_computation.py
new file mode 100644
index 0000000000..850e6bcbbc
--- /dev/null
+++ b/tests/core/metrics/test_metrics_rmse_computation.py
@@ -0,0 +1,377 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests RMSE computation"""
+import numpy as np
+import pytest
+from numpy.testing import assert_almost_equal
+
+from deeplabcut.core.metrics.distance_metrics import (
+ compute_detection_rmse,
+ compute_rmse,
+)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, result",
+ [
+ (
+ [ # ground truth pose
+ [[100.0, 10.0, 2], [150.0, 15.0, 2], [200.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[100.0, 10.0, 0.9], [150.0, 15.0, 0.8], [200.0, 20.0, 0.8]],
+ ],
+ (0, 0),
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[12.0, 10.0, 0.9], [12.0, 10.0, 0.9], [12.0, 10.0, 0.9]],
+ [[22.0, 20.0, 0.9], [22.0, 20.0, 0.9], [22.0, 20.0, 0.9]],
+ ],
+ (2, 2),
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[10.0, 12.0, 0.9], [10.0, 12.0, 0.9], [10.0, 12.0, 0.9]],
+ [[20.0, 22.0, 0.9], [20.0, 22.0, 0.9], [20.0, 22.0, 0.9]],
+ ],
+ (2, 2),
+ ),
+ ],
+)
+def test_rmse_single_image(gt: list, pred: list, result: tuple[float, float]):
+ data = [(np.asarray(gt), np.asarray(pred))]
+ computed_results = compute_rmse(data, False, pcutoff=0.6, oks_bbox_margin=10.0)
+ rmse, rmse_cutoff = computed_results["rmse"], computed_results["rmse_pcutoff"]
+ expected_rmse, expected_rmse_cutoff = result
+ assert_almost_equal(rmse, expected_rmse)
+ assert_almost_equal(rmse_cutoff, expected_rmse_cutoff)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, result",
+ [
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ [[20.0, 22.0, 0.2], [20.0, 22.0, 0.2], [20.0, 22.0, 0.2]],
+ ],
+ (1, 0), # 2 pixel error on half of keypoints, 0 on the other half
+ ),
+ ],
+)
+def test_rmse_pcutoff(gt: list, pred: list, result: tuple[float, float]):
+ data = [(np.asarray(gt), np.asarray(pred))]
+ expected_rmse, expected_rmse_cutoff = result
+
+ computed_results = compute_rmse(data, False, pcutoff=0.6, oks_bbox_margin=10.0)
+ rmse, rmse_cutoff = computed_results["rmse"], computed_results["rmse_pcutoff"]
+ assert_almost_equal(rmse, expected_rmse)
+ assert_almost_equal(rmse_cutoff, expected_rmse_cutoff)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, result",
+ [
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [float("nan"), float("nan"), 0], [10.0, 10.0, 2]],
+ ],
+ [ # predicted pose
+ [[12.0, 10.0, 0.9], [10.0, 10.0, 0.4], [10.0, 10.0, 0.9]],
+ ],
+ (1, 1), # only 2 valid ground truth bodyparts
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [float("nan"), float("nan"), 0]],
+ [[float("nan"), float("nan"), 0], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose, swapped prediction order
+ [[20.0, 20.0, 0.9], [21.0, 20.0, 0.9], [21.0, 20.0, 0.9]],
+ [[15.0, 10.0, 0.4], [15.0, 10.0, 0.4], [10.0, 10.0, 0.9]],
+ ],
+ (3, 1), # only 2 valid GT bodyparts
+ ),
+ ],
+)
+def test_rmse_with_nans(gt: list, pred: list, result: tuple[float, float]):
+ data = [(np.asarray(gt), np.asarray(pred))]
+ expected_rmse, expected_rmse_cutoff = result
+
+ results = compute_rmse(data, False, pcutoff=0.6, oks_bbox_margin=10.0)
+ rmse, rmse_cutoff = results["rmse"], results["rmse_pcutoff"]
+ assert_almost_equal(rmse, expected_rmse)
+ assert_almost_equal(rmse_cutoff, expected_rmse_cutoff)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, data_unique, result",
+ [
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [np.nan, np.nan, 0], [10.0, 10.0, 2]],
+ ],
+ [ # predicted pose
+ [[12.0, 10.0, 0.9], [10.0, 10.0, 0.4], [10.0, 10.0, 0.9]],
+ ],
+ None, # unique data
+ (1, 1), # error 2 on one, 0 on the other; only 2 valid GT
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ [ # predicted pose, perfect detections but mis-assembled
+ [[10.0, 10.0, 0.9], [50.0, 50.0, 0.9], [30.0, 30.0, 0.9]],
+ [[40.0, 40.0, 0.9], [20.0, 20.0, 0.4], [60.0, 60.0, 0.9]],
+ ],
+ None, # unique data
+ (0, 0), # all pose perfect
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ [ # predicted pose, small error in pose and mis-assembled
+ [[12.0, 10.0, 0.9], [52.0, 50.0, 0.9], [32.0, 30.0, 0.9]],
+ [[42.0, 40.0, 0.9], [18.0, 20.0, 0.4], [62.0, 60.0, 0.9]],
+ ],
+ None, # unique data
+ (2, 2), # pixel error of 2 on x-axis for all predictions
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ [ # predicted pose, small error in low-conf pose and mis-assembled
+ [[12.0, 10.0, 0.4], [50.0, 50.0, 0.9], [30.0, 30.0, 0.9]],
+ [[40.0, 40.0, 0.9], [22.0, 20.0, 0.4], [62.0, 60.0, 0.4]],
+ ],
+ None, # unique data
+ (1, 0), # error of 2 on half, 0 on the other half (with good conf)
+ ),
+ ( # more ground truth than detections
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ [[70.0, 70.0, 2], [80.0, 80.0, 2], [90.0, 90.0, 2]],
+ ],
+ [ # predicted pose, no error
+ [[70.0, 70.0, 2], [80.0, 80.0, 2], [90.0, 90.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ None, # unique data
+ (0, 0),
+ ),
+ ( # more detections than GT
+ [ # ground truth pose
+ [[70.0, 70.0, 2], [80.0, 80.0, 2], [90.0, 90.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ [ # predicted pose, no error
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ [[70.0, 70.0, 2], [80.0, 80.0, 2], [90.0, 90.0, 2]],
+ ],
+ None, # unique data
+ (0, 0),
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [np.nan, np.nan, 0], [10.0, 10.0, 2]],
+ ],
+ [ # predicted pose
+ [[12.0, 10.0, 0.9], [10.0, 10.0, 0.4], [10.0, 10.0, 0.9]],
+ ],
+ ( # unique data
+ [[[20, 20, 2], [22, 23, 2]]],
+ [[[20, 20, 0.8], [22, 23, 0.7]]]
+ ),
+ (0.5, 0.5), # error 2 on one, 0 on the other; only 2 valid GT
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [20.0, 20.0, 2], [30.0, 30.0, 2]],
+ [[40.0, 40.0, 2], [50.0, 50.0, 2], [60.0, 60.0, 2]],
+ ],
+ [ # predicted pose, perfect detections but mis-assembled
+ [[10.0, 10.0, 0.9], [50.0, 50.0, 0.9], [30.0, 30.0, 0.9]],
+ [[40.0, 40.0, 0.9], [20.0, 20.0, 0.4], [60.0, 60.0, 0.9]],
+ ],
+ ( # unique data
+ [], # missing ground truth for unique bodyparts
+ [[[20, 20, 0.8], [22, 23, 0.7]]]
+ ),
+ (0, 0), # all pose perfect
+ ),
+ ],
+)
+def test_detection_rmse(gt: list, pred: list, data_unique:tuple[list, list]|None, result: tuple[float, float]):
+ data = [(np.asarray(gt), np.asarray(pred))]
+ data_unique = [(np.asarray(data_unique[0]), np.asarray(data_unique[1]))] if data_unique else None
+ expected_rmse, expected_rmse_cutoff = result
+ rmse, rmse_cutoff = compute_detection_rmse(data, pcutoff=0.6, data_unique=data_unique)
+ assert_almost_equal(rmse, expected_rmse)
+ assert_almost_equal(rmse_cutoff, expected_rmse_cutoff)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, unique_gt, unique_pred, result",
+ [
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ [[20.0, 24.0, 0.2], [20.0, 24.0, 0.2], [20.0, 20.0, 0.2]],
+ ],
+ [ # Unique GT
+ [[10.0, 10.0, 2], [10.0, 10.0, 2]],
+ ],
+ [ # Unique Pred
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ ],
+ # 4 pixel error on 2 keypoints, 0 error on 5 keypoints
+ (1.0, 0.0),
+ ),
+ (
+ [np.zeros((0, 3, 2))], # no GT pose
+ [ # predicted pose
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ ],
+ [ # Unique GT
+ [[10.0, 10.0, 2], [10.0, 10.0, 2]],
+ ],
+ [ # Unique Pred
+ [[15.0, 10.0, 0.5], [11.0, 10.0, 0.9]],
+ ],
+ # 5 pixel error on 1 keypoint, 1 pixel error on the other
+ (3.0, 1.0),
+ ),
+ ],
+)
+def test_rmse_with_unique(
+ gt: list,
+ pred: list,
+ unique_gt: list,
+ unique_pred: list,
+ result: tuple[float, float]
+) -> None:
+ data = [(np.asarray(gt), np.asarray(pred))]
+ data_unique = [(np.asarray(unique_gt), np.asarray(unique_pred))]
+ expected_rmse, expected_rmse_cutoff = result
+
+ results = compute_rmse(
+ data, False, pcutoff=0.6, data_unique=data_unique, oks_bbox_margin=10.0,
+ )
+ rmse, rmse_cutoff = results["rmse"], results["rmse_pcutoff"]
+ assert_almost_equal(rmse, expected_rmse)
+ assert_almost_equal(rmse_cutoff, expected_rmse_cutoff)
+
+
+@pytest.mark.parametrize(
+ "gt, pred, unique_gt, unique_pred, result",
+ [
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ [[20.0, 24.0, 0.2], [20.0, 24.0, 0.2], [20.0, 20.0, 0.2]],
+ ],
+ [ # Unique GT
+ [[10.0, 10.0, 2], [10.0, 10.0, 2]],
+ ],
+ [ # Unique Pred
+ [[10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ ],
+ # 4 pixel error on 2 keypoints, 0 error on 5 keypoints
+ [
+ (1.0, 0.0),
+ [2.0, 2.0, 0.0],
+ [0.0, 0.0]
+ ],
+ ),
+ (
+ [ # ground truth pose
+ [[10.0, 10.0, 2], [10.0, 10.0, 2], [10.0, 10.0, 2]],
+ [[20.0, 20.0, 2], [20.0, 20.0, 2], [20.0, 20.0, 2]],
+ ],
+ [ # predicted pose
+ [[10.0, 12.0, 0.9], [10.0, 10.0, 0.9], [10.0, 10.0, 0.9]],
+ [[20.0, 24.0, 0.7], [20.0, 24.0, 0.6], [20.0, 20.0, 0.8]],
+ ],
+ [ # Unique GT
+ [[10.0, 10.0, 2], [10.0, 10.0, 2]],
+ ],
+ [ # Unique Pred
+ [[12.0, 10.0, 0.9], [11.0, 10.0, 0.9]],
+ ],
+ [ # errors: 3 with 0px, 1 with 1px, 2 with 2px, 2 with 4px => 13/8
+ (1.625, 1.625),
+ [3.0, 2.0, 0.0],
+ [2.0, 1.0]
+ ],
+ ),
+ ],
+)
+def test_rmse_per_bodypart_with_unique(
+ gt: list,
+ pred: list,
+ unique_gt: list,
+ unique_pred: list,
+ result: tuple[tuple[float, float], list[float], list[float]]
+) -> None:
+ data = [(np.asarray(gt), np.asarray(pred))]
+ data_unique = [(np.asarray(unique_gt), np.asarray(unique_pred))]
+ expected_rmse, expected_rmse_cutoff = result[0]
+ bodypart_rmse = result[1]
+ unique_rmse = result[2]
+
+ results = compute_rmse(
+ data,
+ single_animal=False,
+ pcutoff=0.6,
+ data_unique=data_unique,
+ per_keypoint_results=True,
+ oks_bbox_margin=10.0,
+ )
+ assert_almost_equal(results["rmse"], expected_rmse)
+ assert_almost_equal(results["rmse_pcutoff"], expected_rmse_cutoff)
+ for bpt_index, bpt_rmse in enumerate(bodypart_rmse):
+ key = f"rmse_keypoint_{bpt_index}"
+ assert key in results
+ assert_almost_equal(results[key], bpt_rmse)
+
+ for bpt_index, bpt_rmse in enumerate(unique_rmse):
+ key = f"rmse_unique_keypoint_{bpt_index}"
+ assert key in results
+ assert_almost_equal(results[key], bpt_rmse)
diff --git a/tests/generate_training_dataset/test_trainingset_manipulation.py b/tests/generate_training_dataset/test_trainingset_manipulation.py
new file mode 100644
index 0000000000..867d0ea0e4
--- /dev/null
+++ b/tests/generate_training_dataset/test_trainingset_manipulation.py
@@ -0,0 +1,40 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests for deeplabcut/generate_training_dataset/metadata.py"""
+from __future__ import annotations
+
+import pytest
+
+import deeplabcut.generate_training_dataset.trainingsetmanipulation as trainingsetmanipulation
+
+
+@pytest.mark.parametrize(
+ "train_fraction", [1, 2, 5, 17, 24, 29, 34, 47, 50, 53, 61, 68, 75, 90, 95, 97, 99]
+)
+@pytest.mark.parametrize("n_train", [1, 2, 3, 5, 7, 11, 37, 62, 153])
+@pytest.mark.parametrize("n_test", [1, 2, 3, 5, 7, 13, 19, 85, 112])
+def test_compute_padding(train_fraction: int, n_train: int, n_test: int) -> None:
+ """
+ More complete tests can be run with:
+ "train_fraction": list(range(1, 100))
+ "n_train": list(range(1, 200))
+ "n_test": list(range(1, 200))
+
+ This was done locally, but as it's many many tests to run a subset was selected here
+ """
+ train_frac = train_fraction / 100
+ train_pad, test_pad = trainingsetmanipulation._compute_padding(
+ train_frac, n_train, n_test
+ )
+ print()
+ print(train_fraction, n_train, n_test, train_pad, test_pad)
+ frac = round((n_train + train_pad)/(n_train + n_test + train_pad + test_pad), 2)
+ assert train_frac == frac
diff --git a/tests/generate_training_dataset/test_trainset_metadata.py b/tests/generate_training_dataset/test_trainset_metadata.py
new file mode 100644
index 0000000000..e6c150cdf4
--- /dev/null
+++ b/tests/generate_training_dataset/test_trainset_metadata.py
@@ -0,0 +1,409 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests for deeplabcut/generate_training_dataset/metadata.py"""
+from __future__ import annotations
+import pickle
+
+import pytest
+from ruamel.yaml import YAML
+
+import deeplabcut.generate_training_dataset.metadata as metadata
+from deeplabcut.core.engine import Engine
+from deeplabcut.utils import auxiliaryfunctions
+
+SHUFFLE_DATA = [
+ {"name": "pJun17-t50s1", "index": 1, "train_fraction": 0.5, "split": 1, "engine": "torch"},
+ {"name": "pJun17-t50s2", "index": 2, "train_fraction": 0.5, "split": 1, "engine": "tf"},
+ {"name": "pJun17-t60s1", "index": 1, "train_fraction": 0.6, "split": 2, "engine": "torch"},
+ {"name": "pJun17-t60s2", "index": 2, "train_fraction": 0.6, "split": 3, "engine": "torch"},
+]
+SPLITS_DATA = {
+ 1: {"train": [0, 1], "test": [2, 3]},
+ 2: {"train": [0, 1, 2], "test": [3, 4]},
+ 3: {"train": [4, 3, 2], "test": [1, 0]},
+}
+
+BASE_SPLIT = metadata.DataSplit(train_indices=(1, 2), test_indices=(3, 4))
+# Splits that should be equal to the base
+EQ_SPLIT = metadata.DataSplit(train_indices=(1, 2), test_indices=(3, 4))
+# Splits that should not be equal to the base
+ADD_SPLIT = metadata.DataSplit(train_indices=(1, 2, 5), test_indices=(3, 4))
+ADD_SPLIT2 = metadata.DataSplit(train_indices=(1, 2), test_indices=(3, 4, 5))
+SUBS_SPLIT = metadata.DataSplit(train_indices=(1, 3), test_indices=(2, 4))
+DEL_SPLIT = metadata.DataSplit(train_indices=(1,), test_indices=(3, 4))
+DEL_SPLIT2 = metadata.DataSplit(train_indices=(1, 2), test_indices=(3,))
+
+SHUFFLES = {
+ 1: metadata.ShuffleMetadata("pJun17-t50s1", 0.5, 1, Engine.PYTORCH, BASE_SPLIT),
+ 2: metadata.ShuffleMetadata("pJun17-t50s2", 0.5, 2, Engine.PYTORCH, ADD_SPLIT),
+ 3: metadata.ShuffleMetadata("pJun17-t50s3", 0.5, 3, Engine.TF, BASE_SPLIT),
+ 4: metadata.ShuffleMetadata("pJun17-t50s4", 0.5, 4, Engine.PYTORCH, DEL_SPLIT),
+}
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "shuffles": {SHUFFLE_DATA[idx]["name"]: SHUFFLE_DATA[idx] for idx in [0, 1, 2]},
+ "splits": {idx: SPLITS_DATA[idx] for idx in [1, 2]},
+ },
+ {
+ "shuffles": {SHUFFLE_DATA[idx]["name"]: SHUFFLE_DATA[idx] for idx in [0]},
+ "splits": {idx: SPLITS_DATA[idx] for idx in [1, 2]},
+ },
+ ],
+)
+@pytest.mark.parametrize("load_splits", [True, False])
+def test_load_metadata(tmpdir, data: dict, load_splits: bool):
+ """Tests that loading the metadata from files doesn't fail"""
+ # write data to tmp file
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ with open(meta_path, "w") as f:
+ YAML().dump(data, f)
+
+ print(cfg_path)
+ print(meta_path)
+ print(data["shuffles"])
+ print(data["splits"])
+ print()
+
+ for name, s in data["shuffles"].items():
+ split = data["splits"][s["split"]]
+ train, test = split["train"], split["test"]
+ _create_doc_data(
+ cfg, trainset_dir, s["train_fraction"], s["index"], train, test
+ )
+
+ trainset_meta = metadata.TrainingDatasetMetadata.load(
+ str(cfg_path), load_splits=load_splits
+ )
+ for s in trainset_meta.shuffles:
+ print(s)
+
+ assert len(data["shuffles"]) == len(trainset_meta.shuffles)
+
+ for s in trainset_meta.shuffles:
+ shuffle_in = data["shuffles"][s.name]
+ split_idx = data["splits"][shuffle_in["split"]]
+ assert s.train_fraction == shuffle_in["train_fraction"]
+ assert s.engine == Engine(shuffle_in["engine"])
+ if load_splits:
+ assert s.split is not None
+ assert s.split.train_indices == tuple(split_idx["train"])
+ assert s.split.test_indices == tuple(split_idx["test"])
+ else:
+ assert s.split is None
+ s_with_split = s.load_split(cfg, trainset_dir)
+ assert s_with_split.split.train_indices == tuple(split_idx["train"])
+ assert s_with_split.split.test_indices == tuple(split_idx["test"])
+
+
+@pytest.mark.parametrize("data", [
+ {
+ "task": "ch",
+ "date": "Aug1",
+ "shuffles": (SHUFFLES[1], ),
+ "expected": {
+ "shuffles": {
+ SHUFFLES[1].name: {
+ "index": 1, "train_fraction": 0.5, "split": 1, "engine": "pytorch"
+ }
+ },
+ }
+ },
+ {
+ "task": "t",
+ "date": "Jan1",
+ "shuffles": (SHUFFLES[1], SHUFFLES[3]),
+ "expected": {
+ "shuffles": {
+ SHUFFLES[1].name: {
+ "index": 1, "train_fraction": 0.5, "split": 1, "engine": "pytorch"
+ },
+ SHUFFLES[3].name: {
+ "index": 3,
+ "train_fraction": 0.5,
+ "split": 1,
+ "engine": "tensorflow",
+ },
+ },
+ }
+ },
+ {
+ "task": "t",
+ "date": "Jan1",
+ "shuffles": (SHUFFLES[1], SHUFFLES[2]),
+ "expected": {
+ "shuffles": {
+ SHUFFLES[1].name: {
+ "index": 1, "train_fraction": 0.5, "split": 1, "engine": "pytorch"
+ },
+ SHUFFLES[2].name: {
+ "index": 2, "train_fraction": 0.5, "split": 2, "engine": "pytorch"
+ },
+ },
+ },
+ },
+ {
+ "shuffles": (SHUFFLES[1], SHUFFLES[2], SHUFFLES[3]),
+ "expected": {
+ "shuffles": {
+ SHUFFLES[1].name: {
+ "index": 1, "train_fraction": 0.5, "split": 1, "engine": "pytorch"
+ },
+ SHUFFLES[2].name: {
+ "index": 2, "train_fraction": 0.5, "split": 2, "engine": "pytorch"
+ },
+ SHUFFLES[3].name: {
+ "index": 3,
+ "train_fraction": 0.5,
+ "split": 1,
+ "engine": "tensorflow",
+ },
+ },
+ },
+ },
+])
+def test_save_metadata_simple(tmpdir, data):
+ """Tests that saving the metadata creates the expected file"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ trainset_meta = metadata.TrainingDatasetMetadata(cfg, data["shuffles"])
+ print(trainset_meta)
+
+ trainset_meta.save()
+ with open(meta_path, "r") as f:
+ meta = YAML().load(f)
+ print(data)
+ print(meta)
+ assert data["expected"] == meta
+
+
+@pytest.mark.parametrize("shuffles", [
+ [SHUFFLES[i] for i in indices]
+ for indices in [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3, 4], [1, 2, 3, 4]]
+])
+def test_save_metadata(tmpdir, shuffles):
+ """Tests that saving the metadata and reloading it leads to the same instance"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ for s in shuffles:
+ train, test = s.split.train_indices, s.split.test_indices,
+ _create_doc_data(cfg, trainset_dir, s.train_fraction, s.index, train, test)
+
+ trainset_meta = metadata.TrainingDatasetMetadata(cfg, tuple(shuffles))
+ print(trainset_meta)
+ trainset_meta.save()
+ reloaded = metadata.TrainingDatasetMetadata.load(cfg)
+ print(reloaded)
+ print()
+
+ for s in trainset_meta.shuffles:
+ print(s)
+ print()
+ for s in reloaded.shuffles:
+ print(s)
+ print()
+ reloaded_with_splits = [s.load_split(cfg, trainset_dir) for s in reloaded.shuffles]
+ assert len(reloaded.shuffles) == len(trainset_meta.shuffles)
+ assert len(reloaded_with_splits) == len(trainset_meta.shuffles)
+ assert tuple(reloaded_with_splits) == trainset_meta.shuffles
+
+
+def test_add_shuffle(tmpdir):
+ """Tests that a shuffle can be added correctlt"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ trainset_meta = metadata.TrainingDatasetMetadata(cfg, (SHUFFLES[1], ))
+ trainset_meta_added = trainset_meta.add(SHUFFLES[2])
+ assert len(trainset_meta.shuffles) == 1
+ assert len(trainset_meta_added.shuffles) == 2
+ assert trainset_meta_added.shuffles == (SHUFFLES[1], SHUFFLES[2])
+
+
+def test_add_shuffle_twice(tmpdir):
+ """Tests that a shuffle can be added correctlt"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ trainset_meta = metadata.TrainingDatasetMetadata(cfg, (SHUFFLES[1], ))
+ trainset_meta_added = trainset_meta.add(SHUFFLES[2])
+ trainset_meta_added_2 = trainset_meta.add(SHUFFLES[2])
+ assert len(trainset_meta.shuffles) == 1
+ assert trainset_meta.shuffles == (SHUFFLES[1], )
+ assert len(trainset_meta_added.shuffles) == len(trainset_meta_added_2.shuffles)
+ assert trainset_meta_added.shuffles == trainset_meta_added_2.shuffles
+
+
+def test_add_shuffle_sorts_to_correct_order(tmpdir):
+ """Tests that a shuffle can be added correctlt"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ trainset_meta = metadata.TrainingDatasetMetadata(cfg, (SHUFFLES[1], SHUFFLES[3]))
+ trainset_meta_added = trainset_meta.add(SHUFFLES[2])
+ assert len(trainset_meta.shuffles) == 2
+ assert len(trainset_meta_added.shuffles) == 3
+ assert trainset_meta_added.shuffles == (SHUFFLES[1], SHUFFLES[2], SHUFFLES[3])
+
+
+@pytest.mark.parametrize("shuffles", [
+ indices for indices in [[1], [1, 2], [1, 2, 3], [1, 2, 4], [1, 3, 4], [1, 2, 3, 4]]
+])
+@pytest.mark.parametrize("shuffle_to_add", [1, 2, 3, 4])
+def test_add_shuffle(tmpdir, shuffles, shuffle_to_add):
+ """Tests """
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ trainset_meta = metadata.TrainingDatasetMetadata(
+ cfg, tuple([SHUFFLES[i] for i in shuffles])
+ )
+ if shuffle_to_add in shuffles:
+ with pytest.raises(RuntimeError):
+ trainset_meta_added = trainset_meta.add(
+ SHUFFLES[shuffle_to_add], overwrite=False
+ )
+
+ trainset_meta_added = trainset_meta.add(
+ SHUFFLES[shuffle_to_add], overwrite=True
+ )
+ assert len(trainset_meta_added.shuffles) == len(shuffles)
+ assert [s.index for s in trainset_meta_added.shuffles] == shuffles
+ else:
+ trainset_meta_added = trainset_meta.add(
+ SHUFFLES[shuffle_to_add], overwrite=False
+ )
+ indices = [s.index for s in trainset_meta_added.shuffles]
+ assert len(trainset_meta_added.shuffles) == len(shuffles) + 1
+ assert indices == list(sorted(shuffles + [shuffle_to_add]))
+
+
+@pytest.mark.parametrize(
+ "split1, split2, equal",
+ [
+ (BASE_SPLIT, EQ_SPLIT, True),
+ (BASE_SPLIT, ADD_SPLIT, False),
+ (BASE_SPLIT, ADD_SPLIT2, False),
+ (BASE_SPLIT, SUBS_SPLIT, False),
+ (BASE_SPLIT, DEL_SPLIT, False),
+ (BASE_SPLIT, DEL_SPLIT2, False),
+ ],
+)
+def test_data_split_equality(split1, split2, equal):
+ """Tests that equality functions as expected for DataSplits"""
+ print(split1)
+ print(split2)
+ print(equal)
+ assert (split1 == split2) == equal
+
+
+@pytest.mark.parametrize("split_idx", [1, 4, 20, 1000])
+@pytest.mark.parametrize("indices", [(2, 1), (10, 1), (1, 21, 20), (1, 2, 4, 3)])
+@pytest.mark.parametrize("sorted_indices", [(1, 2), (10, 12), (3, 4), (1, 1000, 1200)])
+def test_data_split_requires_sorted(
+ split_idx: int, indices: tuple[int], sorted_indices: tuple[int]
+):
+ """Tests that equality functions as expected for DataSplits"""
+ with pytest.raises(RuntimeError):
+ metadata.DataSplit(
+ train_indices=tuple(indices), test_indices=tuple(sorted_indices)
+ )
+
+ with pytest.raises(RuntimeError):
+ metadata.DataSplit(
+ train_indices=tuple(sorted_indices), test_indices=tuple(indices)
+ )
+
+ with pytest.raises(RuntimeError):
+ metadata.DataSplit(
+ train_indices=tuple(indices), test_indices=tuple(indices)
+ )
+
+ metadata.DataSplit(
+ train_indices=tuple(sorted_indices), test_indices=tuple(sorted_indices)
+ )
+
+
+@pytest.mark.parametrize("shuffles", [
+ (
+ {"idx": 3, "train": [1], "test": [2], "train_fraction": 0.5},
+ ),
+ (
+ {"idx": 1, "train": [1], "test": [2], "train_fraction": 0.5},
+ {"idx": 5, "train": [1, 2, 3], "test": [4, 5], "train_fraction": 0.6},
+ {"idx": 4, "train": [1, 3], "test": [2], "train_fraction": 0.66},
+ ),
+])
+def test_create_metadata_from_shuffles(tmpdir, shuffles):
+ """Tests that equality functions as expected for DataSplits"""
+ cfg, cfg_path, trainset_dir, meta_path = _create_project_with_config(tmpdir)
+ print(trainset_dir)
+ for s in shuffles:
+ doc = f"Documentation_data-ex_{s['train_fraction']}shuffle{s['idx']}.pickle"
+ doc_path = trainset_dir.join(doc)
+ with open(doc_path, "wb") as f:
+ pickle.dump(
+ [[], s["train"], s["test"], s['train_fraction']], f,
+ pickle.HIGHEST_PROTOCOL
+ )
+
+ trainset_metadata = metadata.TrainingDatasetMetadata.create(cfg)
+ print()
+ print(trainset_metadata)
+ assert len(trainset_metadata.shuffles) == len(shuffles)
+
+ for shuffle_data, shuffle in zip(shuffles, trainset_metadata.shuffles):
+ print(shuffle.index)
+ assert shuffle_data["idx"] == shuffle.index
+ assert shuffle_data["train_fraction"] == shuffle.train_fraction
+ assert tuple(shuffle_data["train"]) == shuffle.split.train_indices
+ assert tuple(shuffle_data["test"]) == shuffle.split.test_indices
+ print()
+
+
+def _create_project_with_config(
+ tmp,
+ task: str = "example",
+ date: str = "Feb21",
+ scorer: str = "wayneRooney",
+ iteration: int = 0,
+ engine: str | None = None,
+):
+ project_dir = tmp.mkdir("ex-ample-2024-02-21")
+ cfg = {
+ "Task": task,
+ "date": date,
+ "scorer": scorer,
+ "iteration": iteration,
+ "project_path": str(project_dir),
+ }
+ if engine is not None:
+ cfg["engine"] = engine
+
+ cfg_path = project_dir.join("config.yaml")
+ with open(cfg_path, "w") as file:
+ YAML().dump(cfg, file)
+
+ it = f"iteration-{iteration}"
+ dir_name = "UnaugmentedDataSet_" + task + date
+ trainset_dir = project_dir.mkdir("training-datasets").mkdir(it).mkdir(dir_name)
+
+ meta_path = trainset_dir.join("metadata.yaml")
+ return cfg, cfg_path, trainset_dir, meta_path
+
+
+def _create_doc_data(
+ cfg,
+ trainset_dir,
+ train_frac,
+ shuffle,
+ train_indices,
+ test_indices,
+) -> None:
+ _, doc_path = auxiliaryfunctions.get_data_and_metadata_filenames(
+ trainset_dir, train_frac, shuffle, cfg
+ )
+ auxiliaryfunctions.save_metadata(
+ doc_path, {}, list(train_indices), list(test_indices), train_frac
+ )
diff --git a/tests/pose_estimation_pytorch/apis/test_apis_evaluate.py b/tests/pose_estimation_pytorch/apis/test_apis_evaluate.py
new file mode 100644
index 0000000000..b2de74ba83
--- /dev/null
+++ b/tests/pose_estimation_pytorch/apis/test_apis_evaluate.py
@@ -0,0 +1,481 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from dataclasses import dataclass
+from unittest.mock import Mock, patch
+
+import numpy as np
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.apis as apis
+import deeplabcut.pose_estimation_pytorch.data as data
+
+
+PREDICT = Mock()
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.evaluation.predict", PREDICT)
+@pytest.mark.parametrize("num_individuals", [1, 2, 5])
+@pytest.mark.parametrize(
+ "bodyparts, error",
+ [
+ (["nose", "left_ear"], [5, 10]),
+ (["nose", "left_ear", "right_ear"], [2, 3, 4]),
+ ]
+)
+def test_evaluate_basic(
+ num_individuals: int,
+ bodyparts: list[str],
+ error: list[float],
+) -> None:
+ print()
+ gt, pred = generate_data(1, num_individuals, len(bodyparts), error)
+
+ pose_runner = Mock()
+
+ PREDICT.return_value = {img: {"bodyparts": pose} for img, pose in pred.items()}
+ loader = build_mock_loader(gt, num_individuals, bodyparts)
+ results, preds = apis.evaluate(pose_runner, loader, mode="test")
+ print("results", results)
+ np.testing.assert_almost_equal(results["rmse"], np.mean(error))
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.evaluation.predict", PREDICT)
+@pytest.mark.parametrize("num_individuals", [1, 2, 5])
+@pytest.mark.parametrize(
+ "bodyparts, error",
+ [
+ (["nose", "left_ear"], [5, 10]),
+ (["nose", "left_ear", "right_ear"], [2, 3, 4]),
+ ]
+)
+@pytest.mark.parametrize(
+ "unique_bodyparts, unique_error",
+ [
+ (["top_left"], [2]),
+ (["top_left", "bottom_right"], [2, 3]),
+ ]
+)
+def test_evaluate_with_unique_bodyparts(
+ num_individuals: int,
+ bodyparts: list[str],
+ error: list[float],
+ unique_bodyparts: list[str],
+ unique_error: list[float],
+) -> None:
+ print()
+ num_images = 5
+ gt, pred = generate_data(num_images, num_individuals, len(bodyparts), error)
+ gt_unique, pred_unique = generate_data(
+ num_images, 1, len(unique_bodyparts), unique_error
+ )
+
+ pose_runner = Mock()
+ PREDICT.return_value = {
+ img: {"bodyparts": pose, "unique_bodyparts": pred_unique[img]}
+ for img, pose in pred.items()
+ }
+ loader = build_mock_loader(
+ gt, num_individuals, bodyparts, gt_unique=gt_unique, unique=unique_bodyparts
+ )
+ results, preds = apis.evaluate(pose_runner, loader, mode="test")
+ idv_errors = np.tile(error, (num_individuals, 1)).reshape(-1)
+ expected_rmse = np.mean(np.concatenate([idv_errors, unique_error]))
+ print(num_individuals)
+ print(error)
+ print(idv_errors)
+ print(unique_error)
+ print(np.concatenate([idv_errors, unique_error]))
+ print(expected_rmse)
+ print("results", results)
+ np.testing.assert_almost_equal(results["rmse"], expected_rmse)
+
+
+@dataclass
+class CompTestConfig:
+ num_individuals: int = 1
+ bodyparts: tuple[str, ...] = ("nose", "left_ear")
+ error: tuple[float, ...] = (5, 10)
+ unique_bodyparts: tuple[str, ...] = ("top_left", )
+ unique_error: tuple[float, ...] = (2, )
+ comparison_bodyparts: str | list[str] | None = None
+ expected_error: float = (2 + 5 + 10) / 3
+
+ def num_bpt(self) -> int:
+ return len(self.bodyparts)
+
+ def num_unique(self) -> int:
+ return len(self.unique_bodyparts)
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.evaluation.predict", PREDICT)
+@pytest.mark.parametrize(
+ "cfg",
+ [
+ CompTestConfig(comparison_bodyparts=None),
+ CompTestConfig(comparison_bodyparts="all"),
+ CompTestConfig(comparison_bodyparts=["nose", "left_ear", "top_left"]),
+ CompTestConfig(num_individuals=2, expected_error=(2 + 5 + 5 + 10 + 10) / 5),
+ CompTestConfig(comparison_bodyparts="nose", expected_error=5),
+ CompTestConfig(comparison_bodyparts=["nose"], expected_error=5),
+ CompTestConfig(comparison_bodyparts=["left_ear"], expected_error=10),
+ CompTestConfig(comparison_bodyparts=["nose", "left_ear"], expected_error=7.5),
+ CompTestConfig(comparison_bodyparts="top_left", expected_error=2),
+ CompTestConfig(comparison_bodyparts=["top_left"], expected_error=2),
+ CompTestConfig(
+ unique_bodyparts=("a", "b", "c"),
+ unique_error=(3.0, 4.0, 5.0),
+ comparison_bodyparts=["a", "b", "c"],
+ expected_error=4,
+ ),
+ CompTestConfig(
+ num_individuals=1,
+ unique_bodyparts=("a", "b", "c"),
+ unique_error=(3.0, 4.0, 5.0),
+ comparison_bodyparts=["nose", "a", "b", "c"],
+ expected_error=(5.0 + 3.0 + 4.0 + 5.0) / 4,
+ ),
+ CompTestConfig(
+ num_individuals=7,
+ unique_bodyparts=("a", "b", "c"),
+ unique_error=(3.0, 4.0, 5.0),
+ comparison_bodyparts=["nose", "left_ear", "a", "b"],
+ expected_error=((7 * 5) + (7 * 10) + 3.0 + 4.0) / (7 + 7 + 2),
+ ),
+ ]
+)
+def test_evaluate_with_comparison_bodyparts(cfg: CompTestConfig) -> None:
+ print()
+ num_images = 5
+ gt, pred = generate_data(num_images, cfg.num_individuals, cfg.num_bpt(), cfg.error)
+ gt_unique, pred_unique = generate_data(num_images, 1, cfg.num_unique(), cfg.unique_error)
+
+ pose_runner = Mock()
+ PREDICT.return_value = {
+ img: {"bodyparts": pose, "unique_bodyparts": pred_unique[img]}
+ for img, pose in pred.items()
+ }
+ loader = build_mock_loader(
+ gt,
+ cfg.num_individuals,
+ cfg.bodyparts,
+ gt_unique=gt_unique,
+ unique=cfg.unique_bodyparts,
+ )
+ results, preds = apis.evaluate(
+ pose_runner, loader, mode="test", comparison_bodyparts=cfg.comparison_bodyparts,
+ )
+ print(cfg)
+ print("results", results)
+ np.testing.assert_almost_equal(results["rmse"], cfg.expected_error)
+
+
+@dataclass
+class KeypointData:
+ img: int
+ idv: int
+ bodypart: str
+ gt: tuple[float, float]
+ pred: tuple[float, float]
+ score: float
+
+ def image(self) -> str:
+ return f"image_{self.img:04d}.png"
+
+ def error(self) -> float:
+ return np.linalg.norm(
+ np.asarray(self.gt, dtype=float) - np.asarray(self.pred, dtype=float)
+ ).item()
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.evaluation.predict", PREDICT)
+@pytest.mark.parametrize(
+ "pcutoff", [0.4, 0.6, 0.8, [0.3, 0.5, 0.7]],
+)
+@pytest.mark.parametrize(
+ "keypoints", [
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(10, 10), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 20), pred=(21, 20), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="c", gt=(20, 20), pred=(20, 22), score=0.5),
+ ],
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(10, 10), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 20), pred=(21, 20), score=0.5),
+ KeypointData(img=0, idv=0, bodypart="c", gt=(30, 30), pred=(30, 32), score=0.2),
+ KeypointData(img=0, idv=1, bodypart="a", gt=(40, 10), pred=(41, 10), score=0.7),
+ KeypointData(img=0, idv=1, bodypart="b", gt=(50, 20), pred=(49, 20), score=0.5),
+ KeypointData(img=0, idv=1, bodypart="c", gt=(60, 20), pred=(58, 20), score=0.2),
+ ],
+ ]
+)
+def test_evaluate_with_pcutoff(
+ pcutoff: float | list[float],
+ keypoints: list[KeypointData],
+) -> None:
+ print()
+
+ images = {d.image() for d in keypoints}
+ individuals = list({d.idv for d in keypoints if d.idv != -1})
+ bodyparts = list({d.bodypart for d in keypoints if d.idv != -1})
+ unique_bodyparts = list({d.bodypart for d in keypoints if d.idv == -1})
+
+ num_idv = len(individuals)
+ num_bodyparts = len(bodyparts)
+ num_unique = len(unique_bodyparts)
+
+ gt, pred = {}, {}
+ for img in images:
+ gt[img] = np.zeros((num_idv, num_bodyparts, 3))
+ pred[img] = np.zeros((num_idv, num_bodyparts, 3))
+
+ errors = []
+ errors_cutoff = []
+ for kpt in keypoints:
+ img = kpt.image()
+ bpt = bodyparts.index(kpt.bodypart)
+
+ gt[img][kpt.idv, bpt, :2] = kpt.gt
+ gt[img][kpt.idv, bpt, 2] = 2
+ pred[img][kpt.idv, bpt, :2] = kpt.pred
+ pred[img][kpt.idv, bpt, 2] = kpt.score
+
+ if isinstance(pcutoff, list):
+ bpt_cutoff = pcutoff[bpt]
+ else:
+ bpt_cutoff = pcutoff
+
+ errors.append(kpt.error())
+ if kpt.score >= bpt_cutoff:
+ errors_cutoff.append(kpt.error())
+
+ print(errors)
+ print(errors_cutoff)
+
+ pose_runner = Mock()
+ PREDICT.return_value = {img: {"bodyparts": pose} for img, pose in pred.items()}
+ loader = build_mock_loader(gt, num_idv, bodyparts)
+ results, preds = apis.evaluate(pose_runner, loader, mode="test", pcutoff=pcutoff)
+ print("results", results)
+ np.testing.assert_almost_equal(results["rmse"], np.mean(errors))
+ np.testing.assert_almost_equal(results["rmse_pcutoff"], np.mean(errors_cutoff))
+ if "rmse_detections" in results:
+ np.testing.assert_almost_equal(
+ results["rmse_detections"], np.mean(errors)
+ )
+ np.testing.assert_almost_equal(
+ results["rmse_detections_pcutoff"], np.mean(errors_cutoff)
+ )
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.evaluation.predict", PREDICT)
+@pytest.mark.parametrize(
+ "pcutoff", [
+ 0.4,
+ 0.6,
+ 0.8,
+ [0.3, 0.5, 0.7, 0.4, 0.6],
+ [0.25, 0.43, 0.61, 0.46, 0.92],
+ [0.12, 0.15, 0.92, 0.97, 0.85],
+ [0.92, 0.97, 0.85, 0.12, 0.15],
+ ],
+)
+@pytest.mark.parametrize(
+ "keypoints", [
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(10, 10), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 20), pred=(21, 20), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="c", gt=(20, 20), pred=(20, 22), score=0.5),
+ KeypointData(img=0, idv=-1, bodypart="u1", gt=(20, 20), pred=(20, 22), score=0.5),
+ KeypointData(img=0, idv=-1, bodypart="u2", gt=(20, 20), pred=(20, 22), score=0.3),
+ ],
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(10, 10), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 20), pred=(21, 20), score=0.5),
+ KeypointData(img=0, idv=0, bodypart="c", gt=(30, 30), pred=(30, 32), score=0.2),
+ KeypointData(img=0, idv=1, bodypart="a", gt=(40, 10), pred=(41, 10), score=0.7),
+ KeypointData(img=0, idv=1, bodypart="b", gt=(50, 20), pred=(49, 20), score=0.5),
+ KeypointData(img=0, idv=1, bodypart="c", gt=(60, 20), pred=(58, 20), score=0.2),
+ KeypointData(img=0, idv=-1, bodypart="u1", gt=(2, 3), pred=(3, 3), score=0.7),
+ KeypointData(img=0, idv=-1, bodypart="u2", gt=(20, 20), pred=(20, 22), score=0.9),
+ ],
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(8, 13), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 27), pred=(21, 20), score=0.5),
+ KeypointData(img=0, idv=0, bodypart="c", gt=(30, 36), pred=(30, 32), score=0.2),
+ KeypointData(img=0, idv=-1, bodypart="u1", gt=(2, 3), pred=(3, 3), score=0.7),
+ KeypointData(img=0, idv=-1, bodypart="u2", gt=(20, 20), pred=(20, 22), score=0.9),
+ KeypointData(img=1, idv=0, bodypart="a", gt=(15, 20), pred=(41, 10), score=0.7),
+ KeypointData(img=1, idv=0, bodypart="b", gt=(20, 12), pred=(49, 20), score=0.5),
+ KeypointData(img=1, idv=0, bodypart="c", gt=(17, 32), pred=(58, 20), score=0.2),
+ KeypointData(img=1, idv=-1, bodypart="u1", gt=(37, 4), pred=(3, 3), score=0.7),
+ KeypointData(img=1, idv=-1, bodypart="u2", gt=(12, 6), pred=(20, 22), score=0.9),
+ ],
+ [
+ KeypointData(img=0, idv=0, bodypart="a", gt=(8, 13), pred=(11, 10), score=0.7),
+ KeypointData(img=0, idv=0, bodypart="b", gt=(20, 27), pred=(21, 20), score=0.5),
+ KeypointData(img=0, idv=-1, bodypart="u1", gt=(30, 36), pred=(30, 32), score=0.2),
+ KeypointData(img=0, idv=-1, bodypart="u2", gt=(2, 3), pred=(3, 3), score=0.7),
+ KeypointData(img=0, idv=-1, bodypart="u3", gt=(20, 20), pred=(20, 22), score=0.9),
+ KeypointData(img=1, idv=0, bodypart="a", gt=(15, 20), pred=(41, 10), score=0.7),
+ KeypointData(img=1, idv=0, bodypart="b", gt=(20, 12), pred=(49, 20), score=0.5),
+ KeypointData(img=1, idv=-1, bodypart="u1", gt=(17, 32), pred=(58, 20), score=0.2),
+ KeypointData(img=1, idv=-1, bodypart="u2", gt=(37, 4), pred=(3, 3), score=0.7),
+ KeypointData(img=1, idv=-1, bodypart="u3", gt=(12, 6), pred=(20, 22), score=0.9),
+ ]
+ ]
+)
+def test_evaluate_with_pcutoff_and_unique_bodyparts(
+ pcutoff: float | list[float],
+ keypoints: list[KeypointData],
+) -> None:
+ print()
+
+ images = {d.image() for d in keypoints}
+ individuals = list({d.idv for d in keypoints if d.idv != -1})
+ bodyparts = list({d.bodypart for d in keypoints if d.idv != -1})
+ unique_bodyparts = list({d.bodypart for d in keypoints if d.idv == -1})
+
+ num_idv = len(individuals)
+ num_bodyparts = len(bodyparts)
+ num_unique = len(unique_bodyparts)
+
+ gt, pred, gt_unique, pred_unique = {}, {}, {}, {}
+ for img in images:
+ gt[img] = np.zeros((num_idv, num_bodyparts, 3))
+ pred[img] = np.zeros((num_idv, num_bodyparts, 3))
+ gt_unique[img] = np.zeros((1, num_unique, 3))
+ pred_unique[img] = np.zeros((1, num_unique, 3))
+
+ errors, errors_cutoff = [], []
+ for kpt in keypoints:
+ img = kpt.image()
+ if kpt.idv == -1:
+ idv, bpt = 0, unique_bodyparts.index(kpt.bodypart)
+ pcutoff_idx = bpt + len(bodyparts) # offset by number of bodyparts
+ gt_data, pred_data = gt_unique[img], pred_unique[img]
+ else:
+ idv, bpt = kpt.idv, bodyparts.index(kpt.bodypart)
+ pcutoff_idx = bpt
+ gt_data, pred_data = gt[img], pred[img]
+
+ gt_data[idv, bpt, :2] = kpt.gt
+ gt_data[idv, bpt, 2] = 2
+ pred_data[idv, bpt, :2] = kpt.pred
+ pred_data[idv, bpt, 2] = kpt.score
+
+ if isinstance(pcutoff, list):
+ bpt_cutoff = pcutoff[pcutoff_idx]
+ else:
+ bpt_cutoff = pcutoff
+
+ errors.append(kpt.error())
+ if kpt.score >= bpt_cutoff:
+ errors_cutoff.append(kpt.error())
+
+ print(errors)
+ print(errors_cutoff)
+
+ pose_runner = Mock()
+ PREDICT.return_value = {
+ img: {"bodyparts": pose, "unique_bodyparts": pred_unique[img]}
+ for img, pose in pred.items()
+ }
+ loader = build_mock_loader(gt, num_idv, bodyparts, gt_unique, unique_bodyparts)
+ results, preds = apis.evaluate(pose_runner, loader, mode="test", pcutoff=pcutoff)
+
+ print("results", results)
+ np.testing.assert_almost_equal(results["rmse"], np.mean(errors))
+ np.testing.assert_almost_equal(results["rmse_pcutoff"], np.mean(errors_cutoff))
+ if "rmse_detections" in results:
+ np.testing.assert_almost_equal(
+ results["rmse_detections"], np.mean(errors)
+ )
+ np.testing.assert_almost_equal(
+ results["rmse_detections_pcutoff"], np.mean(errors_cutoff)
+ )
+
+
+def generate_data(
+ num_images: int,
+ num_individuals: int,
+ num_bodyparts: int,
+ error: list[float] | tuple[float, ...] | np.ndarray,
+ cutoffs: list[float] | tuple[float, ...] | np.ndarray | None = None,
+ error_cutoff: list[float] | tuple[float, ...] | np.ndarray | None = None,
+) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
+ num_elems = num_individuals * num_bodyparts
+ shape = num_individuals, num_bodyparts, 3
+ error = np.asarray(error)
+ coord_error = (np.sqrt(2) / 2) * error
+
+ gt, pred = {}, {}
+ for img in range(num_images):
+ gt_pose = 100 * np.arange(3 * num_elems, dtype=float).reshape(shape)
+ gt_pose[..., 2] = 2
+ gt[f"img_{img:04d}.png"] = gt_pose
+
+ pred_pose = np.ones(shape, dtype=float)
+ pred_pose[..., :2] = gt_pose[..., :2]
+ pred_pose[:, :, 0] += coord_error
+ pred_pose[:, :, 1] += coord_error
+ pred[f"img_{img:04d}.png"] = pred_pose
+
+ if error_cutoff is not None and cutoffs is not None:
+ for img in range(num_images):
+ gt_pose = 100 * np.arange(3 * num_elems, dtype=float).reshape(shape)
+ gt_pose[..., 2] = 2
+ gt[f"img_{num_images + img:04d}.png"] = gt_pose
+
+ pred_pose = np.ones(shape, dtype=float)
+ pred_pose[..., :2] = gt_pose[..., :2]
+ pred_pose[..., 2] = cutoffs
+ pred_pose[:, :, 0] += coord_error
+ pred_pose[:, :, 1] += coord_error
+ pred[f"img_{num_images + img:04d}.png"] = pred_pose
+
+ return gt, pred
+
+
+def build_mock_loader(
+ gt: dict[str, np.ndarray],
+ num_individuals: int,
+ bodyparts: list[str] | tuple[str, ...],
+ gt_unique: dict[str, np.ndarray] | None = None,
+ unique: list[str] | tuple[str, ...] | None = None,
+) -> Mock:
+ if unique is None:
+ unique = []
+
+ def _gt(mode: str, unique_bodypart: bool = False) -> dict[str, np.ndarray]:
+ if unique_bodypart:
+ print("LOADING UNIQUE GT")
+ return gt_unique
+ print("LOADING GT")
+ return gt
+
+ individuals = [f"animal_{i:03d}" for i in range(num_individuals)]
+ loader = Mock()
+ loader.get_dataset_parameters.return_value = data.PoseDatasetParameters(
+ bodyparts=bodyparts,
+ unique_bpts=unique,
+ individuals=individuals,
+ )
+ loader.ground_truth_keypoints = _gt
+ loader.model_cfg = {
+ "metadata": {
+ "bodyparts": bodyparts,
+ "unique_bodyparts": unique,
+ "individuals": individuals,
+ "with_identity": False,
+ },
+ "train_settings": {},
+ }
+ return loader
diff --git a/tests/pose_estimation_pytorch/apis/test_apis_export.py b/tests/pose_estimation_pytorch/apis/test_apis_export.py
new file mode 100644
index 0000000000..89b93e2ac3
--- /dev/null
+++ b/tests/pose_estimation_pytorch/apis/test_apis_export.py
@@ -0,0 +1,310 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests exporting models"""
+import copy
+import shutil
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.apis.export as export
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.pose_estimation_pytorch import Task
+from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
+
+
+@pytest.fixture()
+def project_dir(tmp_path_factory) -> Path:
+ project_dir = tmp_path_factory.mktemp("tmp-project")
+ print(f"\nTemporary project directory:")
+ print(str(project_dir))
+ print("---")
+ yield project_dir
+ shutil.rmtree(str(project_dir))
+
+
+def _mock_multianimal_project(project_dir: Path):
+ video_dir = project_dir / "videos"
+ video_dir.mkdir(exist_ok=True)
+
+ cfg_file, yaml_file = af.create_config_template(multianimal=True)
+ cfg_file["Task"] = "mock"
+ cfg_file["scorer"] = "mock"
+ cfg_file["video_sets"] = {str(video_dir / "vid.mp4"): dict(crop="0, 640, 0, 480")}
+ cfg_file["project_path"] = str(project_dir)
+ cfg_file["individuals"] = ["a", "b"]
+ cfg_file["uniquebodyparts"] = []
+ cfg_file["multianimalbodyparts"] = ["k1", "k2", "k3"]
+ cfg_file["bodyparts"] = "MULTI!"
+
+ with open(project_dir / "config.yaml", "w") as f:
+ yaml_file.dump(cfg_file, f)
+
+
+def _make_mock_loader(
+ project_path: Path,
+ project_task: str,
+ project_iteration: int,
+ model_folder: Path,
+ net_type: str,
+ pose_task: Task,
+ default_snapshot_index: int | str,
+ default_detector_snapshot_index: int | str,
+) -> Mock:
+ loader = Mock()
+ loader.project_path = project_path
+ loader.model_folder = model_folder
+ loader.pose_task = pose_task
+ loader.shuffle = 0
+
+ loader.project_cfg = dict(
+ project_path=str(project_path),
+ Task=project_task,
+ date="Jan12",
+ TrainingFraction=[0.95],
+ snapshotindex=default_snapshot_index,
+ detector_snapshotindex=default_detector_snapshot_index,
+ iteration=project_iteration,
+ )
+ loader.model_cfg = dict(
+ net_type=net_type,
+ metadata=dict(
+ project_path=str(project_path),
+ pose_config_path=str(loader.model_folder / "pytorch_config.yaml"),
+ ),
+ weight_init=None,
+ resume_training_from=None,
+ )
+ if pose_task == Task.TOP_DOWN:
+ loader.model_cfg["detector"] = dict(resume_training_from=None)
+
+ return loader
+
+
+def _get_export_model_data(
+ project_dir: Path,
+ num_snapshots: int,
+ task: Task,
+ project_iteration: int = 0,
+):
+ _mock_multianimal_project(project_dir)
+
+ model_dir = Path(project_dir) / f"iteration-{project_iteration}" / "fake-shuffle-0"
+ model_dir.mkdir(exist_ok=True, parents=True)
+ snapshots = []
+ snapshot_data = []
+ for i in range(num_snapshots):
+ snapshot = dict(model=dict(idx=i))
+ snapshot_path = model_dir / f"snapshot-{i:03}.pt"
+ torch.save(snapshot, snapshot_path)
+ snapshots.append(Snapshot(best=False, epochs=i, path=snapshot_path))
+ snapshot_data.append(snapshot)
+
+ detector_snapshots = []
+ detector_data = []
+ if task == Task.TOP_DOWN:
+ for i in range(num_snapshots):
+ snapshot = dict(model=dict(idx=i))
+ snapshot_path = model_dir / f"snapshot-detector-{i:03}.pt"
+ torch.save(snapshot, snapshot_path)
+ detector_data.append(snapshot)
+ detector_snapshots.append(
+ Snapshot(best=False, epochs=i, path=snapshot_path)
+ )
+
+ mock_loader = _make_mock_loader(
+ project_path=project_dir,
+ project_task="mock",
+ project_iteration=project_iteration,
+ model_folder=model_dir,
+ net_type="fake-net",
+ pose_task=task,
+ default_snapshot_index=-1,
+ default_detector_snapshot_index=-1,
+ )
+ return mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data
+
+
+@pytest.mark.parametrize(
+ "task, num_snapshots, idx, detector_idx",
+ [
+ (Task.BOTTOM_UP, 10, 0, None),
+ (Task.BOTTOM_UP, 10, 5, None),
+ (Task.BOTTOM_UP, 10, -1, None),
+ (Task.TOP_DOWN, 10, 0, 0),
+ (Task.TOP_DOWN, 10, -1, 0),
+ (Task.TOP_DOWN, 10, -1, 5),
+ (Task.TOP_DOWN, 10, -1, -1),
+ ],
+)
+def test_export_model(
+ project_dir,
+ task: Task,
+ num_snapshots: int,
+ idx: int,
+ detector_idx: int | None,
+):
+ test_data = _get_export_model_data(project_dir, num_snapshots, task)
+ mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
+
+ def get_mock_loader(*args, **kwargs):
+ return mock_loader
+
+ with patch(
+ "deeplabcut.pose_estimation_pytorch.apis.export.dlc3_data.DLCLoader",
+ get_mock_loader,
+ ):
+ # export the model
+ export.export_model(
+ project_dir / "config.yaml",
+ snapshotindex=idx,
+ detector_snapshot_index=detector_idx,
+ )
+
+ # check that the correct snapshot was exported
+ snapshot = snapshots[idx]
+ detector = None
+ if task == Task.TOP_DOWN:
+ detector = detector_snapshots[detector_idx]
+
+ dir_name = export.get_export_folder_name(mock_loader)
+ filename = export.get_export_filename(mock_loader, snapshot, detector)
+ expected_export = project_dir / "exported-models-pytorch" / dir_name / filename
+ assert expected_export.exists()
+
+ # check that content of the exports are correct
+ exported_data = torch.load(expected_export, weights_only=True)
+ assert isinstance(exported_data, dict)
+ assert "config" in exported_data
+ assert exported_data["config"] == mock_loader.model_cfg
+
+ assert "pose" in exported_data
+ assert exported_data["pose"] == snapshot_data[idx]["model"]
+
+ if task == Task.TOP_DOWN:
+ assert "detector" in exported_data
+ assert exported_data["detector"] == detector_data[detector_idx]["model"]
+
+
+@patch("deeplabcut.pose_estimation_pytorch.apis.export.wipe_paths_from_model_config")
+@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
+def test_export_model_clear_paths(mock_wipe: Mock, project_dir, task: Task):
+ test_data = _get_export_model_data(project_dir, 1, task)
+ mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
+
+ def get_mock_loader(*args, **kwargs):
+ return mock_loader
+
+ with patch(
+ "deeplabcut.pose_estimation_pytorch.apis.export.dlc3_data.DLCLoader",
+ get_mock_loader,
+ ):
+ export.export_model(project_dir / "config.yaml", wipe_paths=True)
+
+ # check that wipe_paths_from_model_config was called
+ assert mock_wipe.call_count == 1
+
+
+@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
+@pytest.mark.parametrize("overwrite", [True, False])
+def test_export_overwrite(project_dir, task: Task, overwrite: bool):
+ test_data = _get_export_model_data(project_dir, 1, task)
+ mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
+ snapshot = snapshots[0]
+ detector = None if task == Task.BOTTOM_UP else detector_snapshots[0]
+
+ def get_mock_loader(*args, **kwargs):
+ return mock_loader
+
+ with patch(
+ "deeplabcut.pose_estimation_pytorch.apis.export.dlc3_data.DLCLoader",
+ get_mock_loader,
+ ):
+ dir_name = export.get_export_folder_name(mock_loader)
+ filename = export.get_export_filename(mock_loader, snapshot, detector)
+ expected_export = project_dir / "exported-models-pytorch" / dir_name / filename
+ expected_export.parent.mkdir(exist_ok=False, parents=True)
+
+ # add existing data
+ assert not expected_export.exists()
+ existing_data = dict()
+ torch.save(existing_data, expected_export)
+
+ # export data
+ export.export_model(project_dir / "config.yaml", overwrite=overwrite)
+
+ exported_data = torch.load(expected_export, weights_only=True)
+
+ if overwrite:
+ assert existing_data != exported_data
+ else:
+ assert existing_data == exported_data
+
+
+@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
+@pytest.mark.parametrize("iteration", [5, 12])
+def test_export_change_iteration(project_dir, task: Task, iteration: int):
+ test_data = _get_export_model_data(
+ project_dir,
+ 1,
+ task,
+ project_iteration=0,
+ )
+ mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
+ snapshot = snapshots[0]
+ detector = None if task == Task.BOTTOM_UP else detector_snapshots[0]
+
+ loader_diff_iter = _get_export_model_data(
+ project_dir, 1, task, project_iteration=iteration
+ )[0]
+
+ def get_mock_loader(config, *args, **kwargs):
+ _loader = copy.deepcopy(mock_loader)
+ if isinstance(config, dict):
+ _loader = copy.deepcopy(mock_loader)
+ _loader.project_cfg = config
+ return _loader
+
+ def read_mock_config(*args, **kwargs):
+ return copy.deepcopy(mock_loader.project_cfg)
+
+ # patch the DLCLoader but also read_config
+ with patch(
+ "deeplabcut.pose_estimation_pytorch.apis.export.dlc3_data.DLCLoader",
+ get_mock_loader,
+ ):
+ with patch(
+ "deeplabcut.pose_estimation_pytorch.apis.export.af.read_config",
+ read_mock_config,
+ ):
+ # check no exports exist yet
+ for loader in [mock_loader, loader_diff_iter]:
+ dir_name = export.get_export_folder_name(loader)
+ filename = export.get_export_filename(loader, snapshot, detector)
+ assert not (
+ project_dir / "exported-models-pytorch" / dir_name / filename
+ ).exists()
+
+ # export data
+ export.export_model(project_dir / "config.yaml", iteration=iteration)
+
+ # check the export exists for the correct iteration
+ for loader, file_should_exist in [
+ (mock_loader, False),
+ (loader_diff_iter, True),
+ ]:
+ dir_name = export.get_export_folder_name(loader)
+ filename = export.get_export_filename(loader, snapshot, detector)
+ expected = project_dir / "exported-models-pytorch" / dir_name / filename
+ expected_exists = expected.exists()
+ assert expected_exists == file_should_exist
diff --git a/tests/pose_estimation_pytorch/apis/test_create_tracking_dataset.py b/tests/pose_estimation_pytorch/apis/test_create_tracking_dataset.py
new file mode 100644
index 0000000000..68b729efbd
--- /dev/null
+++ b/tests/pose_estimation_pytorch/apis/test_create_tracking_dataset.py
@@ -0,0 +1,72 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests method to create the tracking dataset in PyTorch"""
+from pathlib import Path
+
+import torch
+
+import deeplabcut.pose_estimation_pytorch as dlc_torch
+import deeplabcut.pose_estimation_pytorch.apis.tracking_dataset as tracking_dataset
+import deeplabcut.pose_estimation_pytorch.models as models
+
+
+class MockLoader(dlc_torch.Loader):
+ """Mock loader for data"""
+
+ def __init__(self, tmp_folder: Path, bodyparts: list[str] | None = None):
+ if bodyparts is None:
+ bodyparts = ["nose", "left_eye", "right_eye", "tail_base"]
+ self.bodyparts = bodyparts
+
+ model_config_path = tmp_folder / "pytorch_config.yaml"
+ dlc_torch.config.make_pytorch_pose_config(
+ project_config=dlc_torch.config.make_basic_project_config(
+ dataset_path=str(tmp_folder),
+ bodyparts=self.bodyparts,
+ max_individuals=3,
+ ),
+ pose_config_path=tmp_folder / "pytorch_config.yaml",
+ net_type="resnet_50",
+ save=True,
+ )
+ super().__init__(model_config_path)
+
+ def load_data(self, mode: str = "train") -> dict[str, list[dict]]:
+ return {
+ "annotations": [],
+ "categories": [],
+ "images": [],
+ }
+
+ def get_dataset_parameters(self) -> dlc_torch.PoseDatasetParameters:
+ return dlc_torch.PoseDatasetParameters(
+ bodyparts=self.bodyparts,
+ unique_bpts=[],
+ individuals=self.model_cfg["metadata"]["individuals"],
+ )
+
+
+def test_build_feature_extraction_runner(tmp_path_factory):
+ tmp_folder = Path(tmp_path_factory.mktemp("tmp-project"))
+
+ loader = MockLoader(tmp_folder=tmp_folder)
+ model = models.PoseModel.build(loader.model_cfg["model"])
+ snapshot_path = loader.model_folder / "snapshot.pt"
+ torch.save(dict(model=model.state_dict()), snapshot_path)
+ _ = tracking_dataset.build_feature_extraction_runner(
+ loader=loader,
+ snapshot_path=snapshot_path,
+ device="cpu",
+ batch_size=1,
+ )
+
+
+
diff --git a/tests/pose_estimation_pytorch/config/test_config_utils.py b/tests/pose_estimation_pytorch/config/test_config_utils.py
new file mode 100644
index 0000000000..1084e5b940
--- /dev/null
+++ b/tests/pose_estimation_pytorch/config/test_config_utils.py
@@ -0,0 +1,66 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Test util functions for config creation"""
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.config.utils as utils
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ dict(
+ config={},
+ num_bodyparts=None,
+ num_individuals=None,
+ backbone_output_channels=None,
+ output_config={},
+ ),
+ dict(
+ config={
+ "a": "num_bodyparts",
+ "b": ["num_bodyparts // 2", "num_bodyparts // 3"],
+ "c": "num_bodyparts x 2",
+ "d": "num_bodyparts + 2",
+ },
+ num_bodyparts=10,
+ num_individuals=None,
+ backbone_output_channels=None,
+ output_config={
+ "a": 10,
+ "b": [5, 3],
+ "c": 20,
+ "d": 12,
+ },
+ ),
+ dict(
+ config={
+ "a": [{"b": "num_individuals x 3"}],
+ "b": [[{"b": "num_bodyparts x 3"}]],
+ },
+ num_bodyparts=10,
+ num_individuals=1,
+ backbone_output_channels=None,
+ output_config={
+ "a": [{"b": 3}],
+ "b": [[{"b": 30}]],
+ },
+ )
+ ],
+)
+def test_replace_default_values_no_extras(data: dict):
+ output_config = utils.replace_default_values(
+ config=data["config"],
+ num_bodyparts=data["num_bodyparts"],
+ num_individuals=data["num_individuals"],
+ backbone_output_channels=data["backbone_output_channels"],
+ )
+ assert output_config == data["output_config"]
diff --git a/tests/pose_estimation_pytorch/config/test_make_pose_config.py b/tests/pose_estimation_pytorch/config/test_make_pose_config.py
new file mode 100644
index 0000000000..3da2f0f1f4
--- /dev/null
+++ b/tests/pose_estimation_pytorch/config/test_make_pose_config.py
@@ -0,0 +1,493 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the pre-processors"""
+import pytest
+
+import deeplabcut.utils.auxiliaryfunctions as af
+from deeplabcut.core.config import pretty_print
+from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
+ make_basic_project_config,
+ make_pytorch_pose_config,
+)
+from deeplabcut.pose_estimation_pytorch.config.utils import (
+ update_config,
+ update_config_by_dotpath,
+)
+
+
+@pytest.mark.parametrize("bodyparts", [["nose"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize(
+ "net_type", ["resnet_50", "resnet_101", "hrnet_w18", "hrnet_w32", "hrnet_w48"]
+)
+def test_make_single_animal_config(bodyparts: list[str], net_type: str):
+ # Single animal projects can't have unique bodyparts
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=False,
+ identity=False,
+ individuals=[],
+ bodyparts=bodyparts,
+ unique_bodyparts=[],
+ )
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ pretty_print(pytorch_pose_config)
+
+ # check heads are there
+ assert "bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ # check that the bodypart head has locref and heatmaps and the correct output shapes
+ bodypart_head = pytorch_pose_config["model"]["heads"]["bodypart"]
+
+ outputs = [("heatmap_config", len(bodyparts))]
+ if bodypart_head["predictor"]["location_refinement"]:
+ outputs += [("locref_config", 2 * len(bodyparts))]
+
+ for name, output_channels in outputs:
+ head = bodypart_head[name]
+ if "final_conv" in head:
+ actual_output_channels = head["final_conv"]["out_channels"]
+ else:
+ actual_output_channels = head["channels"][-1]
+ assert name in bodypart_head
+ assert actual_output_channels == output_channels
+
+
+@pytest.mark.parametrize("multianimal", [True])
+@pytest.mark.parametrize("individuals", [["single"], ["bugs", "daffy"]])
+@pytest.mark.parametrize("bodyparts", [["nose"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize("identity", [False, True])
+@pytest.mark.parametrize("unique_bodyparts", [[], ["tail"]])
+@pytest.mark.parametrize(
+ "net_type", ["resnet_50", "resnet_101", "hrnet_w18", "hrnet_w32", "hrnet_w48"]
+)
+def test_backbone_plus_paf_config(
+ multianimal: bool,
+ individuals: list[str],
+ bodyparts: list[str],
+ identity: bool,
+ unique_bodyparts: list[str],
+ net_type: str,
+):
+ # Single animal projects can't have unique bodyparts
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=multianimal,
+ identity=identity,
+ individuals=individuals,
+ bodyparts=bodyparts,
+ unique_bodyparts=unique_bodyparts,
+ )
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ pretty_print(pytorch_pose_config)
+
+ graph = [
+ [i, j] for i in range(len(bodyparts)) for j in range(i + 1, len(bodyparts))
+ ]
+ num_limbs = len(graph) * 2
+
+ # check heads are there
+ assert "bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ bodypart_head = pytorch_pose_config["model"]["heads"]["bodypart"]
+
+ # check PAF head
+ assert bodypart_head["type"] == "DLCRNetHead"
+ assert bodypart_head["predictor"]["type"] == "PartAffinityFieldPredictor"
+
+ for name, output_channels in [
+ ("heatmap_config", len(bodyparts)),
+ ("locref_config", len(bodyparts) * 2),
+ ("paf_config", num_limbs),
+ ]:
+ print(name, bodypart_head[name]["channels"])
+ assert name in bodypart_head
+ assert bodypart_head[name]["channels"][-1] == output_channels
+
+ if len(unique_bodyparts) > 0:
+ assert "unique_bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ unique_bodypart_head = pytorch_pose_config["model"]["heads"]["unique_bodypart"]
+ for name, output_channels in [
+ ("heatmap_config", len(unique_bodyparts)),
+ ("locref_config", 2 * len(unique_bodyparts)),
+ ]:
+ assert name in unique_bodypart_head
+ assert unique_bodypart_head[name]["channels"][-1] == output_channels
+ assert unique_bodypart_head["target_generator"]["heatmap_mode"] == "KEYPOINT"
+
+ if identity:
+ assert "identity" in pytorch_pose_config["model"]["heads"].keys()
+ id_head = pytorch_pose_config["model"]["heads"]["identity"]
+ assert "heatmap_config" in id_head
+ assert id_head["heatmap_config"]["channels"][-1] == len(individuals)
+ assert "locref_config" not in id_head
+ assert id_head["target_generator"]["heatmap_mode"] == "INDIVIDUAL"
+
+
+@pytest.mark.parametrize(
+ "detector",
+ [
+ (None, "SSDLite"),
+ ("ssdlite", "SSDLite"),
+ ("fasterrcnn_mobilenet_v3_large_fpn", "FasterRCNN"),
+ ("fasterrcnn_resnet50_fpn_v2", "FasterRCNN"),
+ ],
+)
+@pytest.mark.parametrize("individuals", [["single"], ["bugs", "daffy"]])
+@pytest.mark.parametrize("bodyparts", [["nose"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize(
+ "net_type", ["resnet_50", "resnet_101", "hrnet_w18", "hrnet_w32", "hrnet_w48"]
+)
+def test_top_down_config(
+ detector: tuple[str, str],
+ individuals: list[str],
+ bodyparts: list[str],
+ net_type: str,
+):
+ # Single animal projects can't have unique bodyparts
+ detector_type, expected_detector_type = detector
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=True,
+ identity=False,
+ individuals=individuals,
+ bodyparts=bodyparts,
+ unique_bodyparts=[],
+ )
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ top_down=True,
+ detector_type=detector_type,
+ )
+ pretty_print(pytorch_pose_config)
+
+ # check no collate function
+ collate = pytorch_pose_config["data"]["train"].get("collate")
+ print(f"Collate: {collate}")
+ assert not collate
+
+ # check heads are there
+ assert "bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ bodypart_head = pytorch_pose_config["model"]["heads"]["bodypart"]
+
+ # check detector is there
+ assert "detector" in pytorch_pose_config.keys()
+ assert pytorch_pose_config["detector"]["model"]["type"] == expected_detector_type
+
+ for name, output_channels in [
+ ("heatmap_config", len(bodyparts)),
+ ]:
+ print(name, bodypart_head[name]["channels"])
+ assert name in bodypart_head
+ assert bodypart_head[name]["final_conv"]["out_channels"] == output_channels
+
+
+@pytest.mark.parametrize("multianimal", [True])
+@pytest.mark.parametrize("individuals", [["single"], ["bugs", "daffy"]])
+@pytest.mark.parametrize("bodyparts", [["nose"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize("identity", [False, True])
+@pytest.mark.parametrize("unique_bodyparts", [[], ["tail"]])
+@pytest.mark.parametrize("net_type", ["dekr_w18", "dekr_w32", "dekr_w48"])
+def test_make_dekr_config(
+ multianimal: bool,
+ individuals: list[str],
+ bodyparts: list[str],
+ identity: bool,
+ unique_bodyparts: list[str],
+ net_type: str,
+):
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=multianimal,
+ identity=identity,
+ individuals=individuals,
+ bodyparts=bodyparts,
+ unique_bodyparts=unique_bodyparts,
+ )
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ pretty_print(pytorch_pose_config)
+
+ # check heads are there
+ assert "bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ bodypart_head = pytorch_pose_config["model"]["heads"]["bodypart"]
+ for name, output_channels in [
+ ("heatmap_config", len(bodyparts) + 1),
+ ("offset_config", len(bodyparts)),
+ ]:
+ print(name, bodypart_head[name]["channels"])
+ assert name in bodypart_head
+ assert bodypart_head[name]["channels"][-1] == output_channels
+
+ if len(unique_bodyparts) > 0:
+ assert "unique_bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ unique_bodypart_head = pytorch_pose_config["model"]["heads"]["unique_bodypart"]
+ for name, output_channels in [
+ ("heatmap_config", len(unique_bodyparts)),
+ ("locref_config", 2 * len(unique_bodyparts)),
+ ]:
+ assert name in unique_bodypart_head
+ assert unique_bodypart_head[name]["channels"][-1] == output_channels
+ assert unique_bodypart_head["target_generator"]["heatmap_mode"] == "KEYPOINT"
+
+ if identity:
+ assert "identity" in pytorch_pose_config["model"]["heads"].keys()
+ id_head = pytorch_pose_config["model"]["heads"]["identity"]
+ assert "heatmap_config" in id_head
+ assert id_head["heatmap_config"]["channels"][-1] == len(individuals)
+ assert "locref_config" not in id_head
+ assert id_head["target_generator"]["heatmap_mode"] == "INDIVIDUAL"
+
+
+@pytest.mark.parametrize("multianimal", [True])
+@pytest.mark.parametrize("individuals", [["single"], ["bugs", "daffy"]])
+@pytest.mark.parametrize("bodyparts", [["nose", "ears"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize("identity", [False, True])
+@pytest.mark.parametrize("unique_bodyparts", [[], ["tail"]])
+@pytest.mark.parametrize("net_type", ["dlcrnet_stride16_ms5", "dlcrnet_stride32_ms5"])
+def test_make_dlcrnet_config(
+ multianimal: bool,
+ individuals: list[str],
+ bodyparts: list[str],
+ identity: bool,
+ unique_bodyparts: list[str],
+ net_type: str,
+):
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=multianimal,
+ identity=identity,
+ individuals=individuals,
+ bodyparts=bodyparts,
+ unique_bodyparts=unique_bodyparts,
+ )
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ pretty_print(pytorch_pose_config)
+ paf_graph = [
+ [i, j] for i in range(len(bodyparts)) for j in range(i + 1, len(bodyparts))
+ ]
+ num_limbs = len(paf_graph)
+
+ # check heads are there
+ assert "bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ bodypart_head = pytorch_pose_config["model"]["heads"]["bodypart"]
+ for name, output_channels in [
+ ("heatmap_config", len(bodyparts)),
+ ("locref_config", 2 * len(bodyparts)),
+ ("paf_config", 2 * num_limbs),
+ ]:
+ print(name, bodypart_head[name]["channels"])
+ assert name in bodypart_head
+ assert bodypart_head[name]["channels"][-1] == output_channels
+
+ if len(unique_bodyparts) > 0:
+ assert "unique_bodypart" in pytorch_pose_config["model"]["heads"].keys()
+ unique_bodypart_head = pytorch_pose_config["model"]["heads"]["unique_bodypart"]
+ for name, output_channels in [
+ ("heatmap_config", len(unique_bodyparts)),
+ ("locref_config", 2 * len(unique_bodyparts)),
+ ]:
+ assert name in unique_bodypart_head
+ assert unique_bodypart_head[name]["channels"][-1] == output_channels
+ assert unique_bodypart_head["target_generator"]["heatmap_mode"] == "KEYPOINT"
+
+ if identity:
+ assert "identity" in pytorch_pose_config["model"]["heads"].keys()
+ id_head = pytorch_pose_config["model"]["heads"]["identity"]
+ assert "heatmap_config" in id_head
+ assert id_head["heatmap_config"]["channels"][-1] == len(individuals)
+ assert "locref_config" not in id_head
+ assert id_head["target_generator"]["heatmap_mode"] == "INDIVIDUAL"
+
+
+@pytest.mark.parametrize("individuals", [["single"], ["bugs", "daffy"]])
+@pytest.mark.parametrize("bodyparts", [["nose", "eyes"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize("identity", [False, True])
+@pytest.mark.parametrize("unique_bodyparts", [[], ["tail"]])
+@pytest.mark.parametrize("net_type", ["animaltokenpose_base"])
+def test_make_tokenpose_config(
+ individuals: list[str],
+ bodyparts: list[str],
+ identity: bool,
+ unique_bodyparts: list[str],
+ net_type: str,
+):
+ project_config = _make_project_config(
+ project_path="my/little/project",
+ multianimal=True,
+ identity=identity,
+ individuals=individuals,
+ bodyparts=bodyparts,
+ unique_bodyparts=unique_bodyparts,
+ )
+
+ if identity or len(unique_bodyparts) > 0:
+ with pytest.raises(ValueError) as err_info:
+ # Not yet implemented!
+ _ = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ else:
+ pytorch_pose_config = make_pytorch_pose_config(
+ project_config,
+ "pytorch_config.yaml",
+ net_type=net_type,
+ )
+ pretty_print(pytorch_pose_config)
+
+ # check no collate function
+ collate = pytorch_pose_config["data"]["train"].get("collate")
+ print(f"Collate: {collate}")
+ assert not collate
+
+ # check detector is there
+ assert "detector" in pytorch_pose_config
+ assert "data" in pytorch_pose_config["detector"]
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "config": {"a": 0, "b": 0},
+ "updates": {"b": 1},
+ "expected_result": {"a": 0, "b": 1},
+ },
+ {
+ "config": {"a": 0, "b": {"i0": 1, "i1": 2}},
+ "updates": {"b": 1},
+ "expected_result": {"a": 0, "b": 1},
+ },
+ {
+ "config": {"a": 0, "b": {"i0": 1, "i1": 2}},
+ "updates": {"b": {"i0": [1, 2, 3]}},
+ "expected_result": {"a": 0, "b": {"i0": [1, 2, 3], "i1": 2}},
+ },
+ {
+ "config": {"detector": {"batch_size": 1, "epochs": 10, "save_epochs": 5}},
+ "updates": {
+ "batch_size": 1,
+ "detector": {"batch_size": 8, "save_epochs": 1},
+ },
+ "expected_result": {
+ "batch_size": 1,
+ "detector": {"batch_size": 8, "epochs": 10, "save_epochs": 1},
+ },
+ },
+ ],
+)
+def test_update_config(data: dict):
+ result = update_config(config=data["config"], updates=data["updates"])
+ print("\nResult")
+ pretty_print(result)
+ assert result == data["expected_result"]
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "config": {"a": 0, "b": 0},
+ "updates": {"b": 1},
+ "expected_result": {"a": 0, "b": 1},
+ },
+ {
+ "config": {"a": 0, "b": {"i0": 1, "i1": 2}},
+ "updates": {"b": 1},
+ "expected_result": {"a": 0, "b": 1},
+ },
+ {
+ "config": {"a": 0, "b": {"i0": 1, "i1": 2}},
+ "updates": {"b.i0": [1, 2, 3]},
+ "expected_result": {"a": 0, "b": {"i0": [1, 2, 3], "i1": 2}},
+ },
+ {
+ "config": {"detector": {"batch_size": 1, "epochs": 10, "save_epochs": 5}},
+ "updates": {
+ "batch_size": 1,
+ "detector.batch_size": 8,
+ "detector.save_epochs": 1,
+ },
+ "expected_result": {
+ "batch_size": 1,
+ "detector": {"batch_size": 8, "epochs": 10, "save_epochs": 1},
+ },
+ },
+ ],
+)
+def test_update_config_by_dotpath(data: dict):
+ result = update_config_by_dotpath(config=data["config"], updates=data["updates"])
+ print("\nResult")
+ pretty_print(result)
+ assert result == data["expected_result"]
+
+
+def _make_project_config(
+ project_path: str,
+ multianimal: bool,
+ identity: bool,
+ individuals: list[str],
+ bodyparts: list[str],
+ unique_bodyparts: list[str],
+) -> dict:
+ project_config = {
+ "project_path": project_path,
+ "multianimalproject": multianimal,
+ "identity": identity,
+ "uniquebodyparts": unique_bodyparts,
+ }
+
+ if multianimal:
+ project_config["multianimalbodyparts"] = bodyparts
+ project_config["bodyparts"] = "MULTI!"
+ project_config["individuals"] = individuals
+ else:
+ project_config["bodyparts"] = bodyparts
+
+ return project_config
+
+
+@pytest.mark.parametrize("bodyparts", [["nose"], ["nose", "ear", "eye"]])
+@pytest.mark.parametrize("max_idv", [1, 12, 20])
+@pytest.mark.parametrize("multi", [True, False])
+def test_make_basic_project_config(bodyparts: list[str], max_idv: int, multi: bool):
+ if not multi and max_idv > 1:
+ return
+
+ project_config = make_basic_project_config(
+ dataset_path="path/dataset",
+ bodyparts=bodyparts,
+ max_individuals=max_idv,
+ multi_animal=multi,
+ )
+
+ bpts = af.get_bodyparts(project_config)
+ assert bodyparts == bpts
+
+ individuals = project_config["individuals"]
+ assert len(individuals) == max_idv
+ assert len(set(individuals)) == max_idv
diff --git a/tests/pose_estimation_pytorch/data/test_postprocessor.py b/tests/pose_estimation_pytorch/data/test_postprocessor.py
new file mode 100644
index 0000000000..5ed97776c6
--- /dev/null
+++ b/tests/pose_estimation_pytorch/data/test_postprocessor.py
@@ -0,0 +1,305 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the pre-processors"""
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.data.postprocessor import (
+ PredictKeypointIdentities,
+ PrepareBackboneFeatures,
+ RescaleAndOffset,
+ TrimOutputs,
+)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "predictions": [[[0, 0, 0.95], [20, 30, 0.5]]],
+ "offsets": [(0, 0)],
+ "scales": [(1, 1)],
+ "rescaled": [[[0, 0, 0.95], [20, 30, 0.5]]],
+ },
+ {
+ "predictions": [
+ [[0, 0, 0.12], [1000, 0, 0.5]], # individual 1
+ [[18, 2, 0.24], [0, 1000, 0.6]], # individual 2
+ ],
+ "offsets": [(0, 0), (0, 0)],
+ "scales": [(1, 1), (0.5, 1.0)],
+ "rescaled": [
+ [[0, 0, 0.12], [1000, 0, 0.5]], # individual 1
+ [[9, 2, 0.24], [0, 1000, 0.6]], # individual 2
+ ],
+ },
+ {
+ "predictions": [
+ [[0, 0, 0.95], [20, 30, 0.5]], # individual 1
+ [[110, 5, 0.95], [60, 1200, 0.5]], # individual 2
+ ],
+ "offsets": [(12, 5), (27, 10)],
+ "scales": [(0.5, 0.5), (0.2, 0.2)],
+ "rescaled": [
+ [[12, 5, 0.95], [22, 20, 0.5]], # individual 1
+ [[49, 11, 0.95], [39, 250, 0.5]], # individual 2
+ ],
+ },
+ ],
+)
+def test_rescale_topdown(data):
+ """expects x_processed = x * scale + offset"""
+ postprocessor = RescaleAndOffset(
+ keys_to_rescale=["bodyparts"],
+ mode=RescaleAndOffset.Mode.KEYPOINT_TD,
+ )
+ context = {"scales": data["scales"], "offsets": data["offsets"]}
+ predictions = {"bodyparts": np.array(data["predictions"])}
+ predictions, context = postprocessor(predictions, context=context)
+ print(predictions["bodyparts"].tolist())
+ print(data["rescaled"])
+ np.testing.assert_array_equal(predictions["bodyparts"], np.array(data["rescaled"]))
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "bboxes": [[0, 0, 0, 0], [1, 1, 1, 1]],
+ "bbox_scores": [0, 0],
+ "max_individuals": {"bboxes": 1, "bbox_scores": 1},
+ },
+ {
+ "bboxes": [[0, 0, 0, 0], [1, 1, 1, 1]],
+ "bbox_scores": [0, 0],
+ "max_individuals": {"bboxes": 2, "bbox_scores": 2},
+ },
+ ],
+)
+def test_trim_outputs(data):
+ """expects x_processed = x * scale + offset"""
+ postprocessor = TrimOutputs(max_individuals=data["max_individuals"])
+ context = {}
+ predictions = {"bboxes": np.array(data["bboxes"]), "bbox_scores": np.array(data["bbox_scores"])}
+ predictions, context = postprocessor(predictions, context=context)
+ print(predictions["bboxes"].tolist())
+ print(predictions["bbox_scores"].tolist())
+ assert len(predictions["bboxes"]) == data["max_individuals"]["bboxes"]
+ assert len(predictions["bbox_scores"]) == data["max_individuals"]["bbox_scores"]
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "predictions": [[[0, 0, 0.95], [20, 30, 0.5]]],
+ "offsets": (0, 0),
+ "scales": (1, 1),
+ "rescaled": [[[0, 0, 0.95], [20, 30, 0.5]]],
+ },
+ {
+ "predictions": [
+ [[0, 0, 0.12], [10, 0, 0.5]], # individual 1
+ [[1000, 500, 0.24], [50, 250, 0.6]], # individual 2
+ ],
+ "offsets": (5, 7),
+ "scales": (0.2, 0.5),
+ "rescaled": [
+ [[5, 7, 0.12], [7, 7, 0.5]], # individual 1
+ [[205, 257, 0.24], [15, 132, 0.6]], # individual 2
+ ],
+ },
+ ],
+)
+def test_rescale_bottom_up(data):
+ """expects x_processed = x * scale + offset"""
+ postprocessor = RescaleAndOffset(
+ keys_to_rescale=["bodyparts"],
+ mode=RescaleAndOffset.Mode.KEYPOINT,
+ )
+ context = {"scales": data["scales"], "offsets": data["offsets"]}
+ predictions = {"bodyparts": np.array(data["predictions"])}
+ predictions, context = postprocessor(predictions, context=context)
+ print(predictions["bodyparts"].tolist())
+ print(data["rescaled"])
+ np.testing.assert_array_equal(predictions["bodyparts"], np.array(data["rescaled"]))
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "bboxes": [[222.0, 562.0, 721.0, 637.0]],
+ "offsets": (0, 0),
+ "scales": (1, 1),
+ "rescaled": [[222.0, 562.0, 721.0, 637.0]],
+ },
+ {
+ "bboxes": [[386.71875, 219.53125, 281.640625, 248.828125]],
+ "offsets": (-768, 0),
+ "scales": (2.56, 2.56),
+ "rescaled": [[222.0, 562.0, 721.0, 637.0]],
+ },
+ {
+ "bboxes": [
+ [0, 0, 100, 100],
+ [5, 10, 100, 100],
+ [5, 10, 10, 20],
+ ],
+ "offsets": (3, 7),
+ "scales": (2, 0.5),
+ "rescaled": [
+ [3, 7, 200, 50],
+ [13, 12, 200, 50],
+ [13, 12, 20, 10],
+ ],
+ },
+ ],
+)
+def test_rescale_detector(data):
+ """expects x_processed = x * scale + offset"""
+ postprocessor = RescaleAndOffset(
+ keys_to_rescale=["bboxes"],
+ mode=RescaleAndOffset.Mode.BBOX_XYWH,
+ )
+ context = {"scales": data["scales"], "offsets": data["offsets"]}
+ predictions = {"bboxes": np.array(data["bboxes"])}
+ predictions, context = postprocessor(predictions, context=context)
+ print(predictions["bboxes"].tolist())
+ print(data["rescaled"])
+ np.testing.assert_array_equal(predictions["bboxes"], np.array(data["rescaled"]))
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "bodyparts": [
+ [[3.1, 1, 0.8], [1, 0, 0.9]], # assembly 1 (x, y, score)
+ [[2.2, 1.6, 0.5], [3, 3, 0.4]], # assembly 2 (x, y, score)
+ ],
+ "id_heatmap": [ # id1, id2 score for each pixel
+ [[0.1, 0.1], [0.2, 0.1], [0.3, 0.1], [0.4, 0.1]],
+ [[0.1, 0.2], [0.2, 0.2], [0.3, 0.2], [0.4, 0.2]],
+ [[0.1, 0.3], [0.2, 0.3], [0.3, 0.3], [0.4, 0.3]],
+ [[0.1, 0.4], [0.2, 0.4], [0.3, 0.4], [0.4, 0.4]],
+ ],
+ "id_scores": [ # id1, id2 score for each bodypart
+ [[0.4, 0.2], [0.2, 0.1]], # assembly 1 (id_1 proba, id_2 proba)
+ [[0.3, 0.3], [0.4, 0.4]], # assembly 2 (id_1 proba, id_2 proba)
+ ],
+ },
+ ],
+)
+def test_assign_id_scores(data):
+ p = PredictKeypointIdentities(
+ identity_key="keypoint_identity",
+ identity_map_key="identity_map",
+ pose_key="bodyparts",
+ keep_id_maps=True,
+ )
+ bodyparts = np.array(data["bodyparts"])
+ id_heatmap = np.array(data["id_heatmap"])
+ expected_ids = np.array(data["id_scores"])
+ print()
+ print(bodyparts.shape)
+ print(id_heatmap.shape)
+ print(expected_ids.shape)
+ predictions_in = {"bodyparts": bodyparts, "identity_map": id_heatmap}
+ predictions, _ = p(predictions_in, {})
+ np.testing.assert_array_equal(
+ predictions["keypoint_identity"],
+ expected_ids,
+ )
+
+
+def test_prepare_backbone_features():
+ p = PrepareBackboneFeatures(top_down=False)
+
+ img_w, img_h = 256, 128
+ features = np.zeros((1, img_h, img_w))
+
+ features[0, 15, 10] = 1
+ features[0, 25, 20] = 2
+ features[0, 35, 30] = 3
+
+ pose = np.array([
+ [
+ [10.1, 15.1, 0.95],
+ [20.1, 25.1, 0.95],
+ [29.9, 34.9, 0.95],
+ ],
+ ])
+
+ predictions = [dict(backbone=dict(features=features), bodypart=dict(poses=pose))]
+ context = dict(image_size=(img_w, img_h))
+ predictions_out, context_out = p(predictions, context)
+
+ assert len(predictions_out) == 1
+ assert len(context_out) == 1
+ preds = predictions_out[0]
+
+ assert "backbone" in preds
+ assert "bodypart_features" in preds["backbone"]
+ bodypart_features = preds["backbone"]["bodypart_features"]
+ print(f"Bodypart features: {bodypart_features.shape}")
+ print(bodypart_features)
+ assert bodypart_features.shape == (1, 3, 1)
+ assert bodypart_features.reshape(-1).tolist() == [1, 2, 3]
+
+
+def test_prepare_top_down_backbone_features():
+ p = PrepareBackboneFeatures(top_down=True)
+
+ img_w, img_h = 256, 256
+
+ features = np.zeros((2, 1, img_h, img_w))
+ features[0, 0, 15, 10] = 1
+ features[0, 0, 25, 20] = 2
+ features[0, 0, 35, 30] = 3
+ features[1, 0, 95, 10] = 11
+ features[1, 0, 85, 20] = 12
+ features[1, 0, 75, 30] = 13
+
+ pose_idv0 = np.array([
+ [
+ [10.1, 15.1, 0.95],
+ [20.1, 25.1, 0.95],
+ [29.9, 34.9, 0.95],
+ ],
+ ])
+ pose_idv1 = np.array(
+ [
+ [
+ [10.1, 95.1, 0.95],
+ [20.1, 85.1, 0.95],
+ [29.9, 74.9, 0.95],
+ ],
+ ]
+ )
+
+ predictions = [
+ dict(backbone=dict(features=features[0]), bodypart=dict(poses=pose_idv0)),
+ dict(backbone=dict(features=features[1]), bodypart=dict(poses=pose_idv1)),
+ ]
+ context = dict(top_down_crop_size=(img_w, img_h))
+ predictions_out, context_out = p(predictions, context)
+
+ assert len(predictions_out) == 2
+ assert len(context_out) == 1
+ for preds, expected in zip(predictions_out, [[1, 2, 3], [11, 12, 13]]):
+ assert "backbone" in preds
+ assert "bodypart_features" in preds["backbone"]
+ bodypart_features = preds["backbone"]["bodypart_features"]
+ print(f"Bodypart features: {bodypart_features.shape}")
+ print(bodypart_features)
+ assert bodypart_features.shape == (1, 3, 1)
+ assert bodypart_features.reshape(-1).tolist() == expected
diff --git a/tests/pose_estimation_pytorch/data/test_preprocessor.py b/tests/pose_estimation_pytorch/data/test_preprocessor.py
new file mode 100644
index 0000000000..8c32edbefc
--- /dev/null
+++ b/tests/pose_estimation_pytorch/data/test_preprocessor.py
@@ -0,0 +1,61 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the pre-processors"""
+import albumentations as A
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.data.transforms import build_resize_transforms
+from deeplabcut.pose_estimation_pytorch.data.preprocessor import AugmentImage
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "image_shape": (2, 4, 4),
+ "resize_transform": {"height": 5, "width": 4, "keep_ratio": True},
+ "output_shape": (2, 4, 4),
+ "padded_shape": (5, 4, 4), # single offset as not a batch
+ "output_context": {"offsets": (0, 0), "scales": (1, 1)}
+ },
+ {
+ "image_shape": (1, 2, 4, 4), # as batch
+ "resize_transform": {"height": 10, "width": 4, "keep_ratio": True},
+ "output_shape": (1, 2, 4, 4),
+ "padded_shape": (1, 10, 4, 4),
+ "output_context": {"offsets": [(0, 0)], "scales": [(1, 1)]}
+ },
+ {
+ "image_shape": (2, 4, 3),
+ "resize_transform": {"height": 10, "width": 8, "keep_ratio": True},
+ "output_shape": (4, 8, 3),
+ "padded_shape": (10, 8, 3),
+ "output_context": {"offsets": (0, 0), "scales": (0.5, 0.5)}
+ },
+ ],
+)
+def test_augment_image_rescaling(data):
+ resize_transform = build_resize_transforms(data["resize_transform"])
+ transform = A.Compose(
+ resize_transform,
+ keypoint_params=A.KeypointParams("xy", remove_invisible=False),
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ preprocessor = AugmentImage(transform)
+ img = np.ones(data["image_shape"])
+ transformed_image, context = preprocessor(img, context={})
+ print()
+ print(transformed_image[:, :, 0]) # first channel
+ print(context)
+ assert np.sum(transformed_image) == np.sum(np.ones(data["output_shape"]))
+ assert context == data["output_context"]
+ assert transformed_image.shape == data["padded_shape"]
diff --git a/tests/pose_estimation_pytorch/data/test_transforms.py b/tests/pose_estimation_pytorch/data/test_transforms.py
new file mode 100644
index 0000000000..482120bdfa
--- /dev/null
+++ b/tests/pose_estimation_pytorch/data/test_transforms.py
@@ -0,0 +1,292 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the custom transforms"""
+import random
+
+import albumentations as A
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.data import transforms
+
+
+@pytest.mark.parametrize(
+ "height, width, image_shapes",
+ [
+ (200, 200, [(300, 300, 3), (1000, 1000, 3), (1024, 1024, 1)]),
+ (512, 512, [(1024, 1024, 3), (128, 128, 4), (300, 300, 1)]),
+ (1024, 512, [(600, 300, 3), (4096, 2048, 3), (50, 25, 1)]),
+ (800, 1300, [(80, 130, 3), (1600, 2600, 4), (1200, 1950, 1)]),
+ ],
+)
+def test_dlc_resize_pad_good_aspect_ratio(height, width, image_shapes):
+ aug = transforms.KeepAspectRatioResize(width=width, height=height, mode="pad")
+ for image_shape in image_shapes:
+ fake_image = np.zeros(image_shape)
+ transformed = aug(image=fake_image, keypoints=[])
+ assert transformed["image"].shape[:2] == (height, width)
+ assert transformed["image"].shape[2] == fake_image.shape[2]
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "height": 200,
+ "width": 200,
+ "in_shapes": [(100, 50, 3), (50, 400, 3)],
+ "out_shapes": [(200, 100, 3), (25, 200, 3)],
+ },
+ {
+ "height": 128,
+ "width": 256,
+ "in_shapes": [(100, 100, 3), (512, 256, 3)],
+ "out_shapes": [(128, 128, 3), (128, 64, 3)],
+ },
+ ],
+)
+def test_dlc_resize_pad_bad_aspect_ratio(data):
+ aug = transforms.KeepAspectRatioResize(width=data["width"], height=data["height"], mode="pad")
+ for in_shape, out_shape in zip(data["in_shapes"], data["out_shapes"]):
+ fake_image = np.zeros(in_shape)
+ transformed = aug(image=fake_image, keypoints=[])
+ assert transformed["image"].shape == out_shape
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "height": 200,
+ "width": 200,
+ "in_shape": (100, 50, 3),
+ "out_shape": (200, 100, 3),
+ "in_keypoints": [(50.0, 50.0), (25.0, 10.0)],
+ "out_keypoints": [(100.0, 100.0), (50.0, 20.0)],
+ },
+ {
+ "height": 512,
+ "width": 256,
+ "in_shape": (1024, 1024, 3),
+ "out_shape": (256, 256, 3),
+ "in_keypoints": [(512.0, 512.0), (100.0, 10.0)],
+ "out_keypoints": [(128.0, 128.0), (25.0, 2.5)],
+ },
+ ],
+)
+def test_dlc_resize_pad_bad_aspect_ratio_with_keypoints(data):
+ aug = transforms.KeepAspectRatioResize(width=data["width"], height=data["height"], mode="pad")
+ transform = A.Compose(
+ [aug],
+ keypoint_params=A.KeypointParams("xy", remove_invisible=False),
+ )
+ fake_image = np.zeros(data["in_shape"])
+ transformed = transform(image=fake_image, keypoints=data["in_keypoints"])
+ assert transformed["image"].shape == data["out_shape"]
+ assert transformed["keypoints"] == data["out_keypoints"]
+
+
+def test_coarse_dropout():
+ aug = transforms.CoarseDropout(
+ max_holes=10,
+ max_height=0.05,
+ min_height=0.01,
+ max_width=0.05,
+ min_width=0.01,
+ p=0.5,
+ )
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=10.0,
+ shift_prob=0.0,
+ scale_factor=[0.1, 2.0],
+ scale_prob=0.0,
+ ),
+ },
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.0,
+ shift_prob=1.0,
+ scale_factor=[1.0, 1.0],
+ scale_prob=1.0,
+ sampling="uniform", # truncnorm throws an error if delta is 0
+ ),
+ },
+ ],
+)
+def test_random_bbox_transform_does_not_modify_with_base_config(data: dict) -> None:
+ _set_random_seed()
+ h, w, c = data["image_shape"]
+
+ # generate 100 bboxes
+ bboxes = _gen_random_bboxes(np.random.default_rng(seed=0), 100, w, h)
+
+ t = A.Compose(
+ [transforms.RandomBBoxTransform(**data["transform_config"])],
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ output = t(
+ image=np.zeros((h, w, c)), bboxes=bboxes, bbox_labels=np.zeros(len(bboxes)),
+ )
+ print("Output bounding boxes")
+ for out_bbox in output["bboxes"]:
+ print(out_bbox)
+ print()
+ bboxes_out = np.asarray(output["bboxes"])
+ print("bboxes")
+ print(bboxes_out)
+ print()
+ np.testing.assert_array_almost_equal(bboxes, bboxes_out)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.0,
+ shift_prob=0.0,
+ scale_factor=[0.25, 0.5],
+ scale_prob=1.0,
+ ),
+ },
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.0,
+ shift_prob=0.0,
+ scale_factor=[1.0, 1.5],
+ scale_prob=1.0,
+ ),
+ },
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.0,
+ shift_prob=0.0,
+ scale_factor=[0.5, 1.25],
+ scale_prob=1.0,
+ ),
+ },
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.0,
+ shift_prob=0.0,
+ scale_factor=[0.5, 1.5],
+ scale_prob=0.5,
+ ),
+ },
+ ],
+)
+def test_random_bbox_transform_scale(data: dict) -> None:
+ _set_random_seed()
+ h, w, c = data["image_shape"]
+
+ # generate 100 bboxes
+ bboxes = _gen_random_bboxes(np.random.default_rng(seed=0), 100, w, h)
+
+ t = A.Compose(
+ [transforms.RandomBBoxTransform(**data["transform_config"])],
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ output = t(
+ image=np.zeros((h, w, c)), bboxes=bboxes, bbox_labels=np.zeros(len(bboxes)),
+ )
+ print("Output bounding boxes")
+ for out_bbox in output["bboxes"]:
+ print(out_bbox)
+ print()
+
+ bboxes_out = np.asarray(output["bboxes"])
+ scale_low, scale_high = data["transform_config"]["scale_factor"]
+ for bbox_in_wh, bbox_out_wh in zip(bboxes[:, 2:], bboxes_out[:, 2:]):
+ print("bbox_in_wh", bbox_in_wh)
+ w, h = bbox_in_wh[0].item(), bbox_in_wh[1].item()
+ w_low, w_high = w * scale_low, w * scale_high
+ h_low, h_high = h * scale_low, h * scale_high
+ print("(w, w_low, w_high)", w, w_low, w_high)
+ print("(h, h_low, h_high)", h, h_low, h_high)
+ assert w_low <= bbox_out_wh[0].item() <= w_high
+ assert h_low <= bbox_out_wh[1].item() <= h_high
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "image_shape": [480, 640, 3],
+ "transform_config": dict(
+ shift_factor=0.1,
+ shift_prob=1.0,
+ scale_factor=[1.0, 1.0],
+ scale_prob=0.0,
+ ),
+ },
+ ],
+)
+def test_random_bbox_transform_shift(data: dict) -> None:
+ _set_random_seed()
+ h, w, c = data["image_shape"]
+
+ # generate 100 bboxes
+ bboxes = _gen_random_bboxes(np.random.default_rng(seed=0), 100, w, h)
+
+ t = A.Compose(
+ [transforms.RandomBBoxTransform(**data["transform_config"])],
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ output = t(
+ image=np.zeros((h, w, c)), bboxes=bboxes, bbox_labels=np.zeros(len(bboxes)),
+ )
+ print("Output bounding boxes")
+ for out_bbox in output["bboxes"]:
+ print(out_bbox)
+ print()
+
+ bboxes_out = np.asarray(output["bboxes"])
+ shift = data["transform_config"]["shift_factor"]
+ for bbox_in, bbox_out in zip(bboxes, bboxes_out):
+ print("bbox_in", bbox_in)
+ x, y, w, h = bbox_in
+ x_out, y_out, w_out, h_out = bbox_out
+ max_shift_x, max_shift_y = w * shift, h * shift
+ assert x - max_shift_x <= x_out <= x + max_shift_x
+ assert y - max_shift_y <= y_out <= y + max_shift_y
+
+
+def _set_random_seed():
+ np.random.seed(0)
+ random.seed(0)
+
+
+def _gen_random_bboxes(
+ gen: np.random.Generator, num_bboxes: int, w: int, h: int,
+) -> np.ndarray:
+ image_wh = np.array([w, h])
+ bboxes = np.zeros((num_bboxes, 4))
+ # sample x, y in the images
+ bboxes[:, :2] = image_wh * gen.random((num_bboxes, 2))
+ # sample w, h with the space remaining
+ bboxes[:, 2:] = (image_wh - bboxes[:, :2]) * gen.random((num_bboxes, 2))
+
+ print()
+ print("Input bounding boxes")
+ print(bboxes)
+ return bboxes
diff --git a/tests/pose_estimation_pytorch/data/test_utils.py b/tests/pose_estimation_pytorch/data/test_utils.py
new file mode 100644
index 0000000000..3b24c4ba12
--- /dev/null
+++ b/tests/pose_estimation_pytorch/data/test_utils.py
@@ -0,0 +1,96 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests data utils"""
+import numpy as np
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.data.utils as utils
+
+
+@pytest.mark.parametrize(
+ "keypoints, expected_bboxes, params",
+ [
+ (
+ [[0, 0, 2], [10, 5, 2]],
+ [0, 0, 10, 5],
+ dict(image_w=1024, image_h=1024, margin=0),
+ ),
+ (
+ [[-1, -1, 2], [3, 4, 2]],
+ [0, 0, 3, 4],
+ dict(image_w=1024, image_h=1024, margin=0),
+ ),
+ (
+ [[0, 0, 2], [10, 5, 2]],
+ [0, 0, 5, 3],
+ dict(image_w=5, image_h=3, margin=0),
+ ),
+ (
+ [[0, 0, 2], [10, 5, 2]],
+ [0, 0, 5, 3],
+ dict(image_w=5, image_h=3, margin=10),
+ ),
+ (
+ [[[0, 0, 2], [10, 5, 2]]],
+ [[0, 0, 10, 5]],
+ dict(image_w=1024, image_h=1024, margin=0),
+ ),
+ (
+ [
+ [[4, 1, 2], [10, 5, 2], [3, 12, 0]],
+ [[7, 3, 2], [2, 0, -1], [1, 12, 2]],
+ ],
+ [
+ [4, 1, 6, 4],
+ [1, 3, 6, 9],
+ ],
+ dict(image_w=1024, image_h=1024, margin=0),
+ ),
+ (
+ [
+ [[4, 1, 2], [10, 5, 2], [3, 12, 0]],
+ [[7, 3, 2], [2, 0, -1], [1, 12, 2]],
+ ],
+ [
+ [2, 0, 10, 7],
+ [0, 1, 9, 13],
+ ],
+ dict(image_w=1024, image_h=1024, margin=2),
+ ),
+ (
+ [
+ [[4, 1, 2], [10, 5, 2], [3, 12, 0]],
+ [[7, 3, 2], [2, 0, -1], [1, 12, 2]],
+ ],
+ [
+ [2, 0, 8, 7],
+ [0, 1, 9, 9],
+ ],
+ dict(image_w=10, image_h=10, margin=2),
+ ),
+ (
+ [
+ [[4, 1, 2], [10, 5, 2], [3, 12, 0]],
+ [[7, 3, 0], [2, 0, -1], [1, 12, 0]],
+ ],
+ [
+ [2, 0, 8, 7],
+ [0, 0, 0, 0],
+ ],
+ dict(image_w=10, image_h=10, margin=2),
+ ),
+ ],
+)
+def test_bbox_from_keypoints(keypoints, expected_bboxes, params):
+ keypoints = np.asarray(keypoints, dtype=float)
+ bboxes = utils.bbox_from_keypoints(keypoints, **params)
+ expected_bboxes = np.asarray(expected_bboxes, dtype=float)
+ np.testing.assert_array_almost_equal(bboxes, expected_bboxes)
diff --git a/tests/pose_estimation_pytorch/models/target_generators/test_heatmap_targets.py b/tests/pose_estimation_pytorch/models/target_generators/test_heatmap_targets.py
new file mode 100644
index 0000000000..d6641f265e
--- /dev/null
+++ b/tests/pose_estimation_pytorch/models/target_generators/test_heatmap_targets.py
@@ -0,0 +1,137 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the heatmap target generators (plateau and gaussian)"""
+import numpy as np
+import torch
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.heatmap_targets import (
+ HeatmapGaussianGenerator,
+)
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "dist_thresh": 3,
+ "num_heatmaps": 1,
+ "in_shape": (3, 3),
+ "out_shape": (3, 3),
+ "centers": [(1, 1)],
+ "expected_output": [
+ [0.7788, 0.8825, 0.7788],
+ [0.8825, 1.0000, 0.8825],
+ [0.7788, 0.8825, 0.7788],
+ ],
+ },
+ {
+ "dist_thresh": 3,
+ "num_heatmaps": 1,
+ "in_shape": (5, 5),
+ "out_shape": (5, 5),
+ "centers": [[1, 1], [2, 2]],
+ "expected_output": [
+ [0.7788, 0.8825, 0.7788, 0.5353, 0.3679],
+ [0.8825, 1.0000, 0.8825, 0.7788, 0.5353],
+ [0.7788, 0.8825, 1.0000, 0.8825, 0.6065],
+ [0.5353, 0.7788, 0.8825, 0.7788, 0.5353],
+ [0.3679, 0.5353, 0.6065, 0.5353, 0.3679],
+ ],
+ },
+ {
+ "dist_thresh": 1,
+ "num_heatmaps": 1,
+ "in_shape": (4, 4),
+ "out_shape": (4, 4),
+ "centers": [[1, 1]],
+ "expected_output": [
+ [0.1054, 0.3247, 0.1054, 0.0036],
+ [0.3247, 1.0, 0.3247, 0.0111],
+ [0.1054, 0.3247, 0.1054, 0.0036],
+ [0.0036, 0.0111, 0.0036, 0.0001]
+ ],
+ },
+ ],
+)
+def test_gaussian_heatmap_generation_single_keypoint(data):
+ dist_thresh = data["dist_thresh"]
+ generator = HeatmapGaussianGenerator(
+ num_heatmaps=data["num_heatmaps"],
+ pos_dist_thresh=dist_thresh,
+ heatmap_mode=HeatmapGaussianGenerator.Mode.KEYPOINT,
+ generate_locref=False,
+ )
+ stride = data["in_shape"][0] / data["out_shape"][0]
+ outputs = torch.zeros((1, data["num_heatmaps"], *data["out_shape"]))
+ ann_shape = (1, len(data["centers"]), data["num_heatmaps"], 2)
+ annotations = {
+ "keypoints": torch.tensor(data["centers"]).reshape(ann_shape) # x, y
+ }
+ targets = generator(stride, {"heatmap": outputs}, annotations)
+
+ print("Targets")
+ print(targets["heatmap"]["target"])
+ print()
+ np.testing.assert_almost_equal(
+ targets["heatmap"]["target"].cpu().numpy().reshape(data["out_shape"]),
+ np.array(data["expected_output"]),
+ decimal=3,
+ )
+
+
+@pytest.mark.parametrize(
+ "batch_size, num_keypoints, image_size",
+ [(2, 2, (64, 64)), (1, 5, (48, 64)), (15, 50, (64, 48))],
+)
+def test_random_gaussian_target_generation(
+ batch_size: int, num_keypoints: int, image_size: tuple, num_animals=1
+):
+ # generate annotations
+ annotations = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_keypoints, 2)
+ )
+ } # batch size, num animals, num keypoints, 2 for x,y
+
+ # model stride 1
+ stride = 1
+
+ # generate predictions
+ predicted_heatmaps = {
+ "heatmap": torch.zeros((batch_size, num_keypoints, *image_size))
+ }
+
+ # generate heatmap
+ generator = HeatmapGaussianGenerator(
+ num_heatmaps=num_keypoints,
+ pos_dist_thresh=17,
+ heatmap_mode=HeatmapGaussianGenerator.Mode.KEYPOINT,
+ generate_locref=False,
+ )
+ targets = generator(stride, predicted_heatmaps, annotations)
+ target_heatmap = targets["heatmap"]["target"].reshape(
+ batch_size, num_keypoints, image_size[0] * image_size[1]
+ )
+
+ # get coords of max value of the heatmap
+ gaus_max = torch.argmax(target_heatmap, dim=2)
+
+ # get unraveled coords
+ x = gaus_max % image_size[1]
+ y = gaus_max // image_size[1]
+
+ # get heatmap center tensor
+ predict_kp = torch.stack((x, y), dim=-1)
+ # Remove num_animals dimension - only one animal is supported
+ annotations["keypoints"] = torch.squeeze(annotations["keypoints"], dim=1)
+
+ # compare heatmap center to annotation
+ assert torch.eq(annotations["keypoints"], predict_kp).all().item()
diff --git a/tests/pose_estimation_pytorch/models/target_generators/test_plateau_targets.py b/tests/pose_estimation_pytorch/models/target_generators/test_plateau_targets.py
new file mode 100644
index 0000000000..4aa7133a4d
--- /dev/null
+++ b/tests/pose_estimation_pytorch/models/target_generators/test_plateau_targets.py
@@ -0,0 +1,89 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the heatmap target generators (plateau and gaussian)"""
+import numpy as np
+import torch
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators.heatmap_targets import (
+ HeatmapGenerator,
+ HeatmapPlateauGenerator,
+)
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ {
+ "dist_thresh": 1,
+ "num_heatmaps": 1,
+ "in_shape": (3, 3),
+ "out_shape": (3, 3),
+ "centers": [(1, 1)],
+ "expected_output": [
+ [0., 1., 0.],
+ [1., 1., 1.],
+ [0., 1., 0.],
+ ],
+ },
+ {
+ "dist_thresh": 2,
+ "num_heatmaps": 1,
+ "in_shape": (5, 5),
+ "out_shape": (5, 5),
+ "centers": [[1, 1], [2, 2]],
+ "expected_output": [
+ [1., 1., 1., 0., 0.],
+ [1., 1., 1., 1., 0.],
+ [1., 1., 1., 1., 1.],
+ [0., 1., 1., 1., 0.],
+ [0., 0., 1., 0., 0.],
+ ],
+ },
+ {
+ "dist_thresh": 2,
+ "num_heatmaps": 1,
+ "in_shape": (4, 4),
+ "out_shape": (4, 4),
+ "centers": [[1, 1]],
+ "expected_output": [
+ [1., 1., 1., 0.],
+ [1., 1., 1., 1.],
+ [1., 1., 1., 0.],
+ [0., 1., 0., 0.],
+ ],
+ },
+ ],
+)
+def test_plateau_heatmap_generation_single_keypoint(data):
+ dist_thresh = data["dist_thresh"]
+ generator = HeatmapPlateauGenerator(
+ num_heatmaps=data["num_heatmaps"],
+ pos_dist_thresh=dist_thresh,
+ heatmap_mode=HeatmapGenerator.Mode.KEYPOINT,
+ generate_locref=False,
+ )
+ stride = data["in_shape"][0] / data["out_shape"][0]
+ outputs = torch.zeros((1, data["num_heatmaps"], *data["out_shape"]))
+ ann_shape = (1, len(data["centers"]), data["num_heatmaps"], 2)
+ annotations = {
+ "keypoints": torch.tensor(data["centers"]).reshape(ann_shape) # x, y
+ }
+ targets = generator(stride, {"heatmap": outputs}, annotations)
+
+ print("Targets")
+ print(targets["heatmap"]["target"])
+ print()
+ np.testing.assert_almost_equal(
+ targets["heatmap"]["target"].cpu().numpy().reshape(data["out_shape"]),
+ np.array(data["expected_output"]),
+ decimal=3,
+ )
diff --git a/tests/tests_modelzoo.py b/tests/pose_estimation_pytorch/modelzoo/test_download.py
similarity index 91%
rename from tests/tests_modelzoo.py
rename to tests/pose_estimation_pytorch/modelzoo/test_download.py
index 555c590307..06cc9857e1 100644
--- a/tests/tests_modelzoo.py
+++ b/tests/pose_estimation_pytorch/modelzoo/test_download.py
@@ -4,12 +4,13 @@
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
-# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
-import dlclibrary
import os
+
+import dlclibrary
import pytest
from dlclibrary.dlcmodelzoo.modelzoo_download import MODELOPTIONS
@@ -29,7 +30,7 @@ def test_download_huggingface_wrong_model():
dlclibrary.download_huggingface_model("wrong_model_name")
-@pytest.mark.skip
+@pytest.mark.skip(reason="slow")
@pytest.mark.parametrize("model", MODELOPTIONS)
def test_download_all_models(tmp_path_factory, model):
test_download_huggingface_model(tmp_path_factory, model)
diff --git a/tests/pose_estimation_pytorch/modelzoo/test_load_superanimal_models.py b/tests/pose_estimation_pytorch/modelzoo/test_load_superanimal_models.py
new file mode 100644
index 0000000000..bdd93a3e10
--- /dev/null
+++ b/tests/pose_estimation_pytorch/modelzoo/test_load_superanimal_models.py
@@ -0,0 +1,31 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import dlclibrary
+import pytest
+import torch
+
+from deeplabcut.pose_estimation_pytorch.modelzoo import get_super_animal_snapshot_path
+
+
+@pytest.mark.skip(reason="require-models")
+def test_load_superanimal_models_weights_only():
+ super_animal_names = dlclibrary.get_available_datasets()
+ for super_animal in super_animal_names:
+ print(f"\nTesting {super_animal}")
+ for detector in dlclibrary.get_available_detectors(super_animal):
+ print(super_animal, detector)
+ path = get_super_animal_snapshot_path(super_animal, detector)
+ snapshot = torch.load(path, map_location="cpu", weights_only=True)
+
+ for pose_model in dlclibrary.get_available_models(super_animal):
+ print(super_animal, pose_model)
+ path = get_super_animal_snapshot_path(super_animal, pose_model)
+ snapshot = torch.load(path, map_location="cpu", weights_only=True)
diff --git a/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py b/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py
new file mode 100644
index 0000000000..f3a7b15c7a
--- /dev/null
+++ b/tests/pose_estimation_pytorch/modelzoo/test_modelzoo_utils.py
@@ -0,0 +1,35 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
+
+
+@pytest.mark.parametrize(
+ "super_animal", ["superanimal_quadruped", "superanimal_topviewmouse"]
+)
+@pytest.mark.parametrize("model_name", ["hrnet_w32"])
+@pytest.mark.parametrize("detector_name", [None, "fasterrcnn_resnet50_fpn_v2"])
+def test_get_config_model_paths(super_animal, model_name, detector_name):
+ model_config = modelzoo.load_super_animal_config(
+ super_animal=super_animal,
+ model_name=model_name,
+ detector_name=detector_name,
+ )
+
+ assert isinstance(model_config, dict)
+ if detector_name is None:
+ assert model_config["method"].lower() == "bu"
+ assert "detector" not in model_config
+ else:
+ assert model_config["method"].lower() == "td"
+ assert "detector" in model_config
diff --git a/tests/pose_estimation_pytorch/modelzoo/test_webapp.py b/tests/pose_estimation_pytorch/modelzoo/test_webapp.py
new file mode 100644
index 0000000000..7f9018f123
--- /dev/null
+++ b/tests/pose_estimation_pytorch/modelzoo/test_webapp.py
@@ -0,0 +1,81 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+
+import cv2
+import numpy as np
+import pytest
+
+from deeplabcut.modelzoo.webapp.inference import SuperanimalPyTorchInference
+from deeplabcut.utils import auxiliaryfunctions
+
+
+@pytest.mark.parametrize("max_individuals", [1, 3])
+@pytest.mark.parametrize(
+ "project_name", ["superanimal_quadruped", "superanimal_topviewmouse"]
+)
+@pytest.mark.parametrize("pose_model_type", ["hrnet_w32"])
+def test_class_init(project_name, pose_model_type, max_individuals):
+ inference_pipeline = SuperanimalPyTorchInference(
+ project_name, pose_model_type, max_individuals=max_individuals
+ )
+
+ assert isinstance(inference_pipeline.config, dict)
+ assert inference_pipeline.config["metadata"]["bodyparts"]
+ assert len(inference_pipeline.config["metadata"]["bodyparts"]) > 0
+
+
+@pytest.mark.skip(reason="require-models")
+@pytest.mark.parametrize(
+ "project_name", ["superanimal_quadruped", "superanimal_topviewmouse"]
+)
+@pytest.mark.parametrize("pose_model_type", ["hrnet_w32"])
+def test_runner_init(project_name, pose_model_type):
+ inference_pipeline = SuperanimalPyTorchInference(
+ project_name, pose_model_type, max_individuals=1
+ )
+ weight_folder = f"{auxiliaryfunctions.get_deeplabcut_path()}/modelzoo/checkpoints"
+ snapshot_path = f"{weight_folder}/{project_name}_{pose_model_type}.pth"
+ detector_path = f"{weight_folder}/{project_name}_fasterrcnn.pt"
+
+ inference_pipeline.initialize_models(snapshot_path, detector_path)
+
+ assert inference_pipeline.models.pose_runner
+ assert inference_pipeline.models.detector_runner
+
+
+@pytest.mark.skip(reason="require-models")
+@pytest.mark.parametrize("max_individuals", [10, 4, 1])
+@pytest.mark.parametrize(
+ "project_name", ["superanimal_quadruped", "superanimal_topviewmouse"]
+)
+@pytest.mark.parametrize("pose_model_type", ["hrnet_w32"])
+def test_predict(project_name, pose_model_type, max_individuals):
+ inference_pipeline = SuperanimalPyTorchInference(
+ project_name, pose_model_type, max_individuals=max_individuals
+ )
+ image_path = "img0001.png"
+ weight_folder = f"{auxiliaryfunctions.get_deeplabcut_path()}/modelzoo/checkpoints"
+ snapshot_path = f"{weight_folder}/{project_name}_{pose_model_type}.pth"
+ detector_path = f"{weight_folder}/{project_name}_fasterrcnn.pt"
+
+ inference_pipeline.initialize_models(snapshot_path, detector_path)
+ frame = {image_path: np.random.rand(100, 100, 3)}
+ response = inference_pipeline.predict(frame)
+
+ assert isinstance(response, dict)
+ assert response["joint_names"] == inference_pipeline.config["bodyparts"]
+ assert response["predictions"][0]["markers"].shape == (
+ max_individuals,
+ len(inference_pipeline.config["bodyparts"]),
+ 3,
+ )
+ assert response["predictions"][0]["image_path"] == image_path
diff --git a/tests/pose_estimation_pytorch/other/test_api_utils.py b/tests/pose_estimation_pytorch/other/test_api_utils.py
new file mode 100644
index 0000000000..bd2cf125bf
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_api_utils.py
@@ -0,0 +1,96 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import random
+
+import numpy as np
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.data.transforms as transforms
+
+transform_dicts = [
+ {"auto_padding": {"pad_height_divisor": 64, "pad_width_divisor": 27}},
+ {"resize": {"height": 512, "width": 256, "keep_ration": True}},
+ {
+ "covering": True,
+ "gaussian_noise": 12.75,
+ "hist_eq": True,
+ "motion_blur": True,
+ "normalize_images": True,
+ "rotation": 30,
+ "scale_jitter": [0.5, 1.25],
+ "auto_padding": {"pad_width_divisor": 64, "pad_height_divisor": 27},
+ },
+ {
+ "covering": True,
+ "gaussian_noise": 100,
+ "hist_eq": True,
+ "motion_blur": True,
+ "normalize_images": True,
+ "rotation": 180,
+ "scale_jitter": [0.03, 20],
+ "auto_padding": {"pad_width_divisor": 64, "pad_height_divisor": 27},
+ },
+]
+
+
+def _get_random_params(transform_idx):
+ return (
+ transform_dicts[transform_idx],
+ (random.randint(100, 1000), random.randint(100, 1000)),
+ random.randint(1, 100),
+ random.randint(1, 100),
+ )
+
+
+@pytest.mark.parametrize(
+ "transform_dict, size_image, num_keypoints, num_animals",
+ [_get_random_params(i) for i in range(4)],
+)
+def test_build_transforms(transform_dict, size_image, num_keypoints, num_animals):
+ transform_bbox_aug = transforms.build_transforms(transform_dict)
+ w, h = size_image
+ for i in range(10):
+ test_image = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
+ bboxes = np.random.randint(0, min(w - 1, h - 1), (num_animals, 4))
+ bboxes[:, 2] = w - bboxes[:, 0]
+ bboxes[:, 3] = h - bboxes[:, 1]
+ keypoints = np.random.randint(0, min(w, h), (num_keypoints, 2))
+
+ with pytest.raises(Exception):
+ transformed = transform_bbox_aug(image=test_image)
+ transformed = transform_bbox_aug(image=test_image, bboxes=bboxes.copy())
+ transformed = transform_bbox_aug(
+ image=test_image, keypoints=keypoints.copy(), bboxes=bboxes.copy()
+ )
+
+ transformed_with_bbox = transform_bbox_aug(
+ image=test_image,
+ keypoints=keypoints.copy(),
+ bboxes=bboxes.copy(),
+ bbox_labels=np.arange(num_animals),
+ class_labels=[0 for _ in range(len(keypoints))]
+ )
+
+ if "resize" in transform_dict.keys():
+ assert transformed_with_bbox["image"].shape[:2] == (
+ transform_dict["resize"]["height"],
+ transform_dict["resize"]["width"],
+ )
+
+ if "auto_padding" in transform_dict.keys():
+ modh, modw = (
+ transform_dict["auto_padding"]["pad_height_divisor"],
+ transform_dict["auto_padding"]["pad_width_divisor"],
+ )
+ assert transformed_with_bbox["image"].shape[0] % modh == 0
+ assert transformed_with_bbox["image"].shape[1] % modw == 0
+
+ assert len(transformed_with_bbox["keypoints"]) == len(keypoints)
diff --git a/tests/pose_estimation_pytorch/other/test_configs/config.yaml b/tests/pose_estimation_pytorch/other/test_configs/config.yaml
new file mode 100644
index 0000000000..15ad6f4678
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_configs/config.yaml
@@ -0,0 +1,106 @@
+ # Project definitions (do not edit)
+Task: openfield
+scorer: Pranav
+date: Aug20
+multianimalproject: false
+identity:
+
+ # Project path (change when moving around)
+project_path: /home/quentin/datasets/Openfield_pytorch
+
+ # Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+ /Data/openfield-Pranav-2018-08-20/videos/m1s1.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m1s2.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m2s1.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m3s1.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m3s2.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m4s1.mp4:
+ crop: 0, 640, 0, 480
+ /Data/openfield-Pranav-2018-08-20/videos/m5s1.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m6s1.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m6s2.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m7s1.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m7s2.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m7s3.mp4:
+ crop: 0, 800, 0, 800
+ /Data/openfield-Pranav-2018-08-20/videos/m8s1.mp4:
+ crop: 0, 800, 0, 800
+
+ /Users/mwmathis/Downloads/ARCricket1.avi:
+ crop: 0, 720, 0, 540
+bodyparts:
+- snout
+- leftear
+- rightear
+- tailbase
+
+flipped_keypoints:
+- 0
+- 2
+- 1
+- 3
+
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+start: 0
+stop: 1
+numframes2pick: 20
+
+ # Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff: 0.4
+dotsize: 8
+alphavalue: 0.7
+colormap: jet
+
+ # Training,Evaluation and Analysis configuration
+TrainingFraction:
+- 0.95
+iteration: 1
+default_net_type: resnet_50
+default_augmenter: default
+snapshotindex: -1
+batch_size: 1
+
+ # Cropping Parameters (for analysis and outlier frame detection)
+cropping: false
+ #if cropping is true for analysis, then set the values here:
+x1: 0
+x2: 640
+y1: 277
+y2: 624
+
+ # Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+- 50
+- 50
+move2corner: true
+croppedtraining:
diff --git a/tests/pose_estimation_pytorch/other/test_configs/pose_cfg.yaml b/tests/pose_estimation_pytorch/other/test_configs/pose_cfg.yaml
new file mode 100644
index 0000000000..ec41492bd4
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_configs/pose_cfg.yaml
@@ -0,0 +1,115 @@
+ # Project definitions (do not edit)
+Task:
+scorer:
+date:
+multianimalproject:
+identity:
+
+ # Project path (change when moving around)
+project_path: /home/quentin/datasets/Openfield_pytorch/dlc-models/iteration-1/openfieldAug20-trainset95shuffle1/train
+
+ # Annotation data set configuration (and individual video cropping parameters)
+video_sets:
+bodyparts:
+
+ # Fraction of video to start/stop when extracting frames for labeling/refinement
+start:
+stop:
+numframes2pick:
+
+ # Plotting configuration
+skeleton: []
+skeleton_color: black
+pcutoff:
+dotsize:
+alphavalue:
+colormap:
+
+ # Training,Evaluation and Analysis configuration
+TrainingFraction:
+iteration:
+default_net_type:
+default_augmenter:
+snapshotindex:
+batch_size: 1
+
+ # Cropping Parameters (for analysis and outlier frame detection)
+cropping:
+ #if cropping is true for analysis, then set the values here:
+x1:
+x2:
+y1:
+y2:
+
+ # Refinement configuration (parameters from annotation dataset configuration also relevant in this stage)
+corner2move2:
+move2corner:
+all_joints:
+- - 0
+- - 1
+- - 2
+- - 3
+all_joints_names:
+- snout
+- leftear
+- rightear
+- tailbase
+alpha_r: 0.02
+apply_prob: 0.5
+contrast:
+ clahe: true
+ claheratio: 0.1
+ histeq: true
+ histeqratio: 0.1
+convolution:
+ edge: false
+ emboss:
+ alpha:
+ - 0.0
+ - 1.0
+ strength:
+ - 0.5
+ - 1.5
+ embossratio: 0.1
+ sharpen: false
+ sharpenratio: 0.3
+cropratio: 0.4
+dataset: training-datasets/iteration-1/UnaugmentedDataSet_openfieldAug20/openfield_Pranav95shuffle1.mat
+dataset_type: default
+decay_steps: 30000
+display_iters: 1000
+global_scale: 0.8
+init_weights: /home/quentin/miniconda/envs/DEEPLABCUT/lib/python3.8/site-packages/deeplabcut/pose_estimation_tensorflow/models/pretrained/resnet_v1_50.ckpt
+intermediate_supervision: false
+intermediate_supervision_layer: 12
+location_refinement: true
+locref_huber_loss: true
+locref_loss_weight: 0.05
+locref_stdev: 7.2801
+lr_init: 0.0005
+max_input_size: 1500
+metadataset: training-datasets/iteration-1/UnaugmentedDataSet_openfieldAug20/Documentation_data-openfield_95shuffle1.pickle
+min_input_size: 64
+mirror: false
+multi_stage: false
+multi_step:
+- - 0.005
+ - 10000
+- - 0.02
+ - 430000
+- - 0.002
+ - 730000
+- - 0.001
+ - 1030000
+net_type: resnet_50
+num_joints: 4
+pairwise_huber_loss: false
+pairwise_predict: false
+partaffinityfield_predict: false
+pos_dist_thresh: 17
+rotation: 25
+rotratio: 0.4
+save_iters: 50000
+scale_jitter_lo: 0.5
+scale_jitter_up: 1.25
+scmap_type: plateau
diff --git a/tests/pose_estimation_pytorch/other/test_configs/pytorch_config.yaml b/tests/pose_estimation_pytorch/other/test_configs/pytorch_config.yaml
new file mode 100644
index 0000000000..0be2ca0ed8
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_configs/pytorch_config.yaml
@@ -0,0 +1,45 @@
+project_root: /home/quentin/datasets/Openfield_pytorch
+pose_cfg_path: /home/quentin/datasets/Openfield_pytorch/dlc-models/iteration-1/openfieldAug20-trainset95shuffle1/train/pose_cfg.yaml
+cfg_path: /home/quentin/datasets/Openfield_pytorch/config.yaml
+
+seed: 42
+device: 'cuda:2' #needs to be updated dynamically; some users might have CPUs
+model:
+ backbone:
+ type: 'ResNet'
+ pretrained: 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
+ heatmap_head:
+ type: 'SimpleHead'
+ channels: [ 2048, 1024, 4 ]
+ kernel_size: [ 2, 2 ]
+ strides: [ 2, 2 ]
+ locref_head:
+ type: 'SimpleHead'
+ channels: [ 2048, 1024, 8 ]
+ kernel_size: [ 2, 2 ]
+ strides: [ 2, 2 ]
+ pose_model:
+ stride: 8
+ heatmap_type: 'plateau'
+optimizer:
+ type: 'SGD'
+ params:
+ lr: 0.005
+scheduler:
+ type: "LRListScheduler"
+ params:
+ milestones : [10, 430]
+ lr_list : [[0.02], [0.002]]
+criterion:
+ type: 'PoseLoss'
+ loss_weight_locref: 0.1
+ locref_huber_loss: True
+#logger:
+# type: 'WandbLogger'
+# project_name: 'deeplabcut'
+# run_name: 'tmp'
+solver:
+ type: 'BottomUpSingleAnimalSolver'
+pos_dist_thresh : 17
+batch_size: 1
+epochs: 600
diff --git a/tests/pose_estimation_pytorch/other/test_custom_transforms.py b/tests/pose_estimation_pytorch/other/test_custom_transforms.py
new file mode 100644
index 0000000000..f312fc9978
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_custom_transforms.py
@@ -0,0 +1,55 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.data import transforms
+
+
+@pytest.mark.parametrize("width, height", [(200, 200), (300, 300), (400, 400)])
+def test_keypoint_aware_cropping(width, height):
+ fake_image = np.empty((600, 600, 3))
+ fake_keypoints = [(i * 100, i * 100, 0, 0) for i in range(1, 6)]
+ aug = transforms.KeypointAwareCrop(
+ width=width, height=height, crop_sampling="density"
+ )
+ transformed = aug(image=fake_image, keypoints=fake_keypoints)
+ assert transformed["image"].shape[:2] == (height, width)
+ # Ensure at least a keypoint is visible in each crop
+ assert len(transformed["keypoints"])
+
+
+def test_grayscale():
+ fake_image = np.ones((600, 600, 3))
+ fake_image *= np.random.uniform(0, 255, size=fake_image.shape)
+ fake_image = fake_image.astype(np.uint8)
+ gray = transforms.Grayscale(alpha=1, p=1)
+ aug_image = gray(image=fake_image)["image"]
+ assert aug_image.shape == fake_image.shape
+
+ gray = transforms.Grayscale(alpha=0, p=1)
+ aug_image = gray(image=fake_image)["image"]
+ assert np.allclose(fake_image, aug_image)
+
+ with pytest.warns(UserWarning, match="clipped"):
+ gray = transforms.Grayscale(alpha=1.5)
+ assert gray.alpha == 1
+
+
+def test_coarse_dropout():
+ fake_image = np.ones((300, 300, 3))
+ fake_image *= np.random.uniform(0, 255, size=fake_image.shape)
+ fake_image = fake_image.astype(np.uint8)
+ cd = transforms.CoarseDropout(max_height=0.9999, max_width=0.9999, p=1)
+ kpts = np.random.rand(10, 2) * 300
+ aug_kpts = cd(image=fake_image, keypoints=kpts)["keypoints"]
+ assert len(aug_kpts) == kpts.shape[0]
+ assert np.isnan([c for kpt in aug_kpts for c in kpt]).all()
diff --git a/tests/pose_estimation_pytorch/other/test_data_helper.py b/tests/pose_estimation_pytorch/other/test_data_helper.py
new file mode 100644
index 0000000000..bfb83dde7f
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_data_helper.py
@@ -0,0 +1,94 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from __future__ import annotations
+
+import os
+from unittest.mock import patch, Mock
+from zipfile import Path
+
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.data.dlcloader import DLCLoader
+from deeplabcut.pose_estimation_pytorch.data.utils import merge_list_of_dicts
+from deeplabcut.generate_training_dataset import create_training_dataset
+
+
+def mock_aux() -> Mock:
+ aux_functions = Mock()
+ aux_functions.read_plainconfig = Mock()
+ aux_functions.read_plainconfig.return_value = {}
+ return aux_functions
+
+
+@patch("deeplabcut.pose_estimation_pytorch.data.base.auxiliaryfunctions", mock_aux())
+def _get_loader(project_root):
+ if not (Path(project_root) / "training-datasets").exists():
+ create_training_dataset(config=str(Path(project_root) / "config.yaml"))
+ return DLCLoader(Path(project_root) / "config.yaml", shuffle=1)
+
+
+@pytest.mark.skip
+@pytest.mark.parametrize("repo_path", ["/home/anastasiia/DLCdev"])
+def test_propertymeta_project(repo_path):
+ project_root = os.path.join(repo_path, "examples", "openfield-Pranav-2018-10-30")
+ dlc_loader = _get_loader(project_root)
+
+ for prop in dlc_loader.properties:
+ print(prop, getattr(dlc_loader, prop))
+
+
+@pytest.mark.skip
+@pytest.mark.parametrize(
+ "repo_path, mode",
+ [("/home/anastasiia/DLCdev", "train"), ("/home/anastasiia/DLCdev", "test")],
+)
+def test_propertymeta_dataset(repo_path, mode):
+ repo_path = "/home/anastasiia/DLCdev"
+ mode = "train"
+ project_root = os.path.join(repo_path, "examples", "openfield-Pranav-2018-10-30")
+ dlc_loader = _get_loader(project_root)
+ dataset = dlc_loader.create_dataset(transform=None, mode=mode)
+
+ for prop in dataset.properties:
+ print(prop, getattr(dataset, prop))
+
+
+@pytest.mark.parametrize(
+ "list_dicts, keys_to_include",
+ [
+ ([{"a": 1, "b": 2}, {"a": 3, "b": 4}], ["a"]),
+ (
+ [
+ *[
+ {
+ "keypoints": np.random.randn(27, 3),
+ "images": np.random.randn(256, 192),
+ }
+ ]
+ * 10
+ ],
+ [*["keypoints", "images"] * 10],
+ ),
+ ],
+)
+def test_merge_list_of_dicts(list_dicts, keys_to_include):
+ result_dict = merge_list_of_dicts(list_dicts, keys_to_include)
+ expected_result_dict = {}
+ for dictionary in list_dicts:
+ for key in dictionary:
+ if key not in keys_to_include:
+ continue
+ else:
+ if key not in expected_result_dict:
+ expected_result_dict[key] = []
+ expected_result_dict[key].append(dictionary[key])
+ assert result_dict == expected_result_dict
diff --git a/tests/pose_estimation_pytorch/other/test_dataset.py b/tests/pose_estimation_pytorch/other/test_dataset.py
new file mode 100644
index 0000000000..6894d8fbe2
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_dataset.py
@@ -0,0 +1,197 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+import random
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import albumentations as A
+import pytest
+from torch.utils.data import DataLoader
+
+import deeplabcut.pose_estimation_pytorch as dlc
+import deeplabcut.utils.auxiliaryfunctions as dlc_auxfun
+from deeplabcut.core.engine import Engine
+from deeplabcut.generate_training_dataset import create_training_dataset
+
+
+def mock_config() -> Mock:
+ aux_functions = Mock()
+ aux_functions.read_config_as_dict = Mock()
+ aux_functions.read_config_as_dict.return_value = {
+ "data": {"train": {}, "inference": {}},
+ "metadata": {
+ "project_path": "",
+ "pose_config_path": "",
+ "bodyparts": ["snout", "leftear", "rightear", "tailbase"],
+ "unique_bodyparts": [],
+ "individuals": ["animal"],
+ "with_identity": False,
+ },
+ "method": "bu",
+ }
+ return aux_functions
+
+
+@patch("deeplabcut.pose_estimation_pytorch.data.base.config", mock_config())
+def _get_dataset(path, transform, mode="train"):
+ project_root = Path(path)
+ if not (project_root / "training-datasets").exists():
+ print(str(project_root / "config.yaml"))
+ create_training_dataset(
+ config=str(project_root / "config.yaml"),
+ net_type="resnet_50",
+ engine=Engine.PYTORCH,
+ )
+
+ loader = dlc.DLCLoader(Path(project_root) / "config.yaml", shuffle=1)
+ dataset = loader.create_dataset(transform=transform, mode=mode)
+ return dataset
+
+
+def _get_openfield_dataset(transform=None):
+ dlc_path = dlc_auxfun.get_deeplabcut_path()
+ repo_path = os.path.dirname(dlc_path)
+ openfield_path = os.path.join(repo_path, "examples", "openfield-Pranav-2018-10-30")
+
+ return _get_dataset(openfield_path, transform=transform)
+
+
+key_set = {
+ "offsets",
+ "path",
+ "scales",
+ "image",
+ "original_size",
+ "annotations",
+ "image_id",
+}
+anno_key_set = {
+ "keypoints",
+ "keypoints_unique",
+ "with_center_keypoints",
+ "area",
+ "boxes",
+ "is_crowd",
+ "labels",
+ "individual_ids",
+}
+
+
+@pytest.mark.parametrize("batch_size", [1, 2, random.randint(2, 20)])
+def test_iter_all_dataset_no_transform(batch_size):
+ if batch_size > 1: # if batched, all images need to be the same size
+ transform = A.Compose(
+ [A.Resize(512, 512)],
+ keypoint_params=A.KeypointParams(format="xy"),
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ else:
+ transform = A.Compose(
+ [A.Normalize()],
+ keypoint_params=A.KeypointParams(format="xy"),
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ dataset = _get_openfield_dataset(transform=transform)
+ dataloader = DataLoader(dataset, batch_size=batch_size)
+ max_num_animals = dataset.parameters.max_num_animals
+ num_keypoints = dataset.parameters.num_joints
+ for i, item in enumerate(dataloader):
+ is_last_batch = i == (len(dataloader) - 1)
+ assert (
+ set(item.keys()) == key_set
+ ), "the key returned don't match the required ones"
+
+ anno = item["annotations"]
+ assert (
+ set(anno.keys()) == anno_key_set
+ ), "the annotation keys returned don't match the required ones"
+
+ assert (len(item["image"].shape) == 4) and (
+ (item["image"].shape[:2] == (batch_size, 3)) or is_last_batch
+ ), "image shape is not (batch_size, 3, h, w)"
+
+ b, _, h, w = item["image"].shape
+ kpts, bboxes = anno["keypoints"], anno["boxes"]
+ assert (
+ kpts.shape == (batch_size, max_num_animals, num_keypoints, 3)
+ or is_last_batch
+ ), "keypoints have the wrong shape"
+ assert (
+ bboxes.shape == (batch_size, max_num_animals, 4) or is_last_batch
+ ), "boxes have the wrong shape"
+ assert ((bboxes[:, :, 0] + bboxes[:, :, 2]) <= w).all() and (
+ (bboxes[:, :, 1] + bboxes[:, :, 3]) <= h
+ ).all(), "boxes don't seem to be un the format (x, y, w, h)"
+
+
+def _generate_random_test_values_aug(min_exa):
+ batch_size = random.randint(1, 20)
+ x_size = random.randint(50, 600)
+ y_size = random.randint(50, 600)
+ exaggeration = random.randint(min_exa, 99)
+
+ return batch_size, x_size, y_size, exaggeration
+
+
+@pytest.mark.parametrize(
+ "batch_size, x_size, y_size, exaggeration",
+ [
+ (1, 512, 512, 1),
+ _generate_random_test_values_aug(1),
+ _generate_random_test_values_aug(50),
+ ],
+)
+def test_iter_all_augmented_dataset(batch_size, x_size, y_size, exaggeration):
+ transform = A.Compose(
+ [
+ A.Affine(
+ scale=(1 - exaggeration * 0.01, 1 + exaggeration),
+ rotate=(-exaggeration * 2, exaggeration * 2),
+ translate_px=(-exaggeration * 10, exaggeration * 10),
+ ),
+ A.Resize(y_size, x_size),
+ ],
+ keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
+ bbox_params=A.BboxParams(format="coco", label_fields=["bbox_labels"]),
+ )
+ dataset = _get_openfield_dataset(transform=transform)
+ dataloader = DataLoader(dataset, batch_size=batch_size)
+ max_num_animals = dataset.parameters.max_num_animals
+ num_keypoints = dataset.parameters.num_joints
+ for i, item in enumerate(dataloader):
+ is_last_batch = i == (len(dataloader) - 1)
+ assert (
+ set(item.keys()) == key_set
+ ), "the key returned don't match the required ones"
+
+ anno = item["annotations"]
+ assert (
+ set(anno.keys()) == anno_key_set
+ ), "the annotation keys returned don't match the required ones"
+
+ assert (len(item["image"].shape) == 4) and (
+ (item["image"].shape[:2] == (batch_size, 3)) or is_last_batch
+ ), "image shape is not (batch_size, 3, h, w)"
+
+ kpts, bboxes = anno["keypoints"], anno["boxes"]
+ b, _, h, w = item["image"].shape
+ assert (h == y_size) and (w == x_size)
+ assert (
+ kpts.shape == (batch_size, max_num_animals, num_keypoints, 3)
+ or is_last_batch
+ ), "keypoints have the wrong shape"
+ assert (
+ bboxes.shape == (batch_size, max_num_animals, 4) or is_last_batch
+ ), "boxes have the wrong shape"
+ assert ((bboxes[:, :, 0] + bboxes[:, :, 2]) <= w).all() and (
+ (bboxes[:, :, 1] + bboxes[:, :, 3]) <= h
+ ).all()
diff --git a/tests/pose_estimation_pytorch/other/test_gaussian_targets.py b/tests/pose_estimation_pytorch/other/test_gaussian_targets.py
new file mode 100644
index 0000000000..57cb212fa8
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_gaussian_targets.py
@@ -0,0 +1,60 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import pytest
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators import HeatmapGaussianGenerator
+
+
+@pytest.mark.parametrize(
+ "batch_size, num_keypoints, image_size",
+ [(2, 2, (64, 64)), (1, 5, (48, 64)), (15, 50, (64, 48))],
+)
+def test_gaussian_target_generation(
+ batch_size: int, num_keypoints: int, image_size: tuple, num_animals=1
+):
+ # generate annotations
+ labels = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_keypoints, 2)
+ )
+ } # batch size, num animals, num keypoints, 2 for x,y
+ # generate predictions
+ stride = 1
+ prediction = {
+ "heatmap": torch.rand((batch_size, num_keypoints, *image_size[:2])),
+ "locref": torch.rand((batch_size, 2 * num_keypoints, *image_size[:2])),
+ }
+
+ # generate heatmap
+ output = HeatmapGaussianGenerator(
+ num_heatmaps=num_keypoints,
+ pos_dist_thresh=17,
+ locref_std=5.0,
+ )
+ output = output(stride, prediction, labels)["heatmap"]["target"].reshape(
+ batch_size, num_keypoints, image_size[0] * image_size[1]
+ )
+
+ # get coords of max value of the heatmap
+ gaus_max = torch.argmax(output, dim=2)
+
+ # get unraveled coords
+ x = gaus_max % image_size[1]
+ y = gaus_max // image_size[1]
+
+ # get heatmap center tensor
+ predict_kp = torch.stack((x, y), dim=-1)
+ # Remove num_animals dimension - only one animal is supported
+ labels["keypoints"] = torch.squeeze(labels["keypoints"], dim=1)
+
+ # compare heatmap center to annotation
+ assert torch.eq(labels["keypoints"], predict_kp).all().item()
diff --git a/tests/pose_estimation_pytorch/other/test_heatmap_plateau_targets.py b/tests/pose_estimation_pytorch/other/test_heatmap_plateau_targets.py
new file mode 100644
index 0000000000..2508d90860
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_heatmap_plateau_targets.py
@@ -0,0 +1,208 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+from typing import Tuple
+
+import pytest
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators import HeatmapPlateauGenerator
+
+
+def get_target(
+ batch_size: int,
+ num_animals: int,
+ num_joints: int,
+ image_size: Tuple[int, int],
+ locref_std: float,
+ pos_dist_thresh: int,
+):
+ """Summary
+ Getting the target generator for certain annotations, predictions and image size.
+
+ Args:
+ batch_size (int): number of images
+ num_animals (int): number of animals
+ num_joints (int): number of bodyparts
+ image_size (tuple): image size in pixels
+ locref_std (float): scaling factor
+ pos_dist_thresh (int): radius plateau on the heatmap
+
+ Returns:
+ target_output (dict): containing the heatmaps, locref_maps and locref_masks.
+ annotations (dict): containing input keypoint annotations.
+
+ Examples:
+ input:
+ batch_size = 1
+ num_animals = 1
+ num_joints = 6
+ image_size = (256,256)
+ locref_stdev = 7.2801
+ pos_dist_thresh = 17
+ output:
+
+ """
+ labels = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_joints, 2)
+ )
+ } # 2 for x,y coords
+ stride = 1
+ prediction = {
+ "heatmap": torch.rand((batch_size, num_joints, image_size[0], image_size[1])),
+ "locref": torch.rand((batch_size, 2 * num_joints, image_size[0], image_size[1])),
+ }
+ generator = HeatmapPlateauGenerator(
+ num_heatmaps=num_joints,
+ pos_dist_thresh=pos_dist_thresh,
+ locref_std=locref_std,
+ generate_locref=True,
+ )
+
+ targets_output = generator(stride, prediction, labels)
+ return targets_output, labels
+
+
+data = [(1, 1, 10, (256, 256), 7.2801, 17)]
+
+
+@pytest.mark.parametrize(
+ "batch_size, num_animals, num_joints, image_size, locref_stdev, pos_dist_thresh",
+ data,
+)
+def test_expected_output(
+ batch_size: int,
+ num_animals: int,
+ num_joints: int,
+ image_size: Tuple[int, int],
+ locref_stdev: float,
+ pos_dist_thresh: int,
+):
+ """Summary:
+ Testing if plateau targets return the expected output. We take a target generator from
+ get_target function. Given a sequence of random numbers for batch_size, num_animals etc., we assert if
+ it returns the expected heatmaps and locrefmaps, as well as checking if the output has the expected shape.
+
+ Args:
+ batch_size (int): number of images
+ num_animals (int): number of animals
+ num_joints (int): number of bodyparts
+ image_size (tuple): image size in pixels
+ locref_stdev (float): scaling factor
+ pos_dist_thresh (int): radius plateau on heatmap
+
+ Returns:
+ None
+
+ Examples:
+ input:
+ batch_size = 1
+ num_animals = 1
+ num_joints = 6
+ image_size = (256,256)
+ locref_stdev = 7.2801
+ pos_dist_thresh = 17
+ """
+ targets_output, annotations = get_target(
+ batch_size, num_animals, num_joints, image_size, locref_stdev, pos_dist_thresh
+ )
+
+ assert "heatmap" in targets_output
+ assert "locref" in targets_output
+ assert targets_output["heatmap"]["target"].shape == (
+ batch_size,
+ num_joints,
+ image_size[0],
+ image_size[1],
+ ) # heatmaps score output
+ assert targets_output["locref"]["weights"].shape == (
+ batch_size,
+ num_joints * 2,
+ image_size[0],
+ image_size[1],
+ )
+ assert targets_output["locref"]["target"].shape == (
+ batch_size,
+ num_joints * 2,
+ image_size[0],
+ image_size[1],
+ )
+
+
+data = [(1, 1, 10, (256, 256), 7.2801, 17)]
+
+
+@pytest.mark.parametrize(
+ "batch_size, num_animals, num_joints, image_size, locref_stdev, pos_dist_thresh",
+ data,
+)
+def test_single_animal(
+ batch_size: int,
+ num_animals: int,
+ num_joints: int,
+ image_size: Tuple[int, int],
+ locref_stdev: float,
+ pos_dist_thresh: int,
+):
+ """Summary
+ Testing, for single animals experiments (num_animals=1) if the distance between the expected keypoints
+ and the annotations keypoints is smaller than the radius plateau.
+
+ 'argmax' function returns the indices of the max values of all elements in the input tensor.
+ If there are multiple maximal values, such as in our case because it's a plateau, then the
+ indices of the first maximal value are returned. From this tensor we exctact x,y coords
+ and then concatenate these new tensors along a new dimension. Then, we assert if the distance between
+ each x,y element in annotations and predicted keypoints is smaller or equal to the 'pos_dist_thresh',
+ which represents the radius of the plateau heatmap.
+
+ Args:
+ batch_size (int): number of images
+ num_animals (int): number of animals
+ num_joints (int): number of bodyparts
+ image_size (tuple): image size in pixels
+ locref_stdev (float): scaling factor
+ pos_dist_thresh (int): radius plateau on heatmap
+
+ Returns:
+ None
+
+ Examples:
+ input:
+ batch_size = 1
+ num_animals = 1
+ num_joints = 6
+ image_size = (256,256)
+ locref_stdev = 7.2801
+ pos_dist_thresh = 17
+ """
+ targets_output, annotations = get_target(
+ batch_size, num_animals, num_joints, image_size, locref_stdev, pos_dist_thresh
+ )
+
+ targets_output = torch.tensor(
+ targets_output["heatmap"]["target"].reshape(1, 10, image_size[0] * image_size[1])
+ ) # converting from dict to tensor. 'argmax' works on tensors.
+
+ plt_max = torch.argmax(targets_output, dim=2)
+ # get unraveled coords
+ x = plt_max % image_size[1]
+ y = plt_max // image_size[1]
+
+ predict_kp = torch.stack((x, y), dim=-1)
+
+ predict_kp = predict_kp.float()
+
+ annotations["keypoints"] = torch.squeeze(annotations["keypoints"], dim=1)
+ annotations["keypoints"] = annotations["keypoints"].float()
+
+ dist = torch.norm(annotations["keypoints"] - predict_kp, p=2, dim=-1)
+ assert (dist <= pos_dist_thresh).all()
diff --git a/tests/pose_estimation_pytorch/other/test_helper.py b/tests/pose_estimation_pytorch/other/test_helper.py
new file mode 100644
index 0000000000..c7cd0fd6f1
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_helper.py
@@ -0,0 +1,21 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import torch
+
+
+def test_train_valid_call():
+ tmp_model = torch.nn.Linear(3, 10)
+ to_train_mode = getattr(tmp_model, "train")
+ to_train_mode()
+ assert tmp_model.training == True
+ to_valid_mode = getattr(tmp_model, "eval")
+ to_valid_mode()
+ assert tmp_model.training == False
diff --git a/tests/pose_estimation_pytorch/other/test_match_predictions_to_gt.py b/tests/pose_estimation_pytorch/other/test_match_predictions_to_gt.py
new file mode 100644
index 0000000000..943e5c9882
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_match_predictions_to_gt.py
@@ -0,0 +1,145 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+
+import numpy as np
+import pytest
+
+import deeplabcut.pose_estimation_pytorch.post_processing.match_predictions_to_gt as deeplabcut_torch_match_predictions_gt
+
+
+@pytest.fixture
+def animals_and_keypoints_invalid():
+ """Summary:
+ Fixture with invalid pred_kpts and gt_kpts shapes that will raise ValueErrors.
+
+ Returns:
+ tuple containing:
+ predicted keypoints(pred_kpts), of shape num_animals, num_keypoints, (x,y,score)
+ ground truth keypoints (gt_kpts), of shape num_animals, num_keypoints, (x,y)
+ individual names (indv_names)
+ """
+ gt_kpts = 2 * np.ones((6, 6, 3)) # num animals, num keypoints, (x,y,vis)
+ gt_kpts[:, :, :2] = np.random.rand(6, 6, 2)
+ pred_kpts = np.random.rand(6, 8, 3) # num animals, num keypoints, (x,y,score)
+ indv_names = ["indv1", "indv2"]
+ return pred_kpts, gt_kpts, indv_names
+
+
+@pytest.fixture
+def animals_and_keypoints():
+ """Summary:
+ Fixture with pred_kpts, gt_kpts shapes and indv_names.
+
+ Returns:
+ tuple containing:
+ predicted keypoints(pred_kpts), of shape num_animals, num_keypoints, (x,y,score)
+ ground truth keypoints (gt_kpts), of shape num_animals, num_keypoints, (x,y)
+ individual names (indv_names)
+ """
+ gt_kpts = 2 * np.ones((6, 6, 3)) # num animals, num keypoints, (x,y,vis)
+ gt_kpts[:, :, :2] = np.random.rand(6, 6, 2)
+
+ # adding score value because the shape of pred_kpts should be (6,6,3)
+ score = np.full((gt_kpts.shape[0], gt_kpts.shape[1], 1), 0.5)
+ pred_kpts = np.concatenate((gt_kpts, score), axis=2)
+ np.random.shuffle(pred_kpts) # shuffle predicted keypoints
+
+ indv_names = ["indv1", "indv2"]
+ return pred_kpts, gt_kpts, indv_names
+
+
+def test_invalid_rmse(animals_and_keypoints_invalid: tuple) -> None:
+ """Summary:
+ Tets if an invalid output really returns a ValueError in the rmse function.
+
+ Args:
+ animals_and_keypoints_invalid (tuple): containing predicted keypoints (pred_kpts),
+ ground truth keypoints (gt_kpts) and individual names (indv_names).
+ """
+ pred_kpts, gt_kpts, indv_names = animals_and_keypoints_invalid
+
+ with pytest.raises(ValueError):
+ deeplabcut_torch_match_predictions_gt.rmse_match_prediction_to_gt(
+ pred_kpts, gt_kpts
+ )
+
+
+def test_invalid_oks(animals_and_keypoints_invalid: tuple) -> None:
+ """Summary:
+ Test if an invalid output really returns a ValueError in the oks function.
+
+ Args:
+ animals_and_keypoints_invalid (tuple): containing predicted keypoints (pred_kpts), ground truth keypoints (gt_kpts)
+ and individual names (indv_names)
+ """
+ pred_kpts, gt_kpts, indv_names = animals_and_keypoints_invalid
+
+ with pytest.raises(ValueError):
+ deeplabcut_torch_match_predictions_gt.oks_match_prediction_to_gt(
+ pred_kpts, gt_kpts, indv_names
+ )
+
+
+def test_rmse_match_predictions_to_gt(
+ animals_and_keypoints: tuple, num_animals: int = 6
+) -> None:
+ """Summary:
+ Test if rmse_match_prediction_to_gt function returns the expected shape output.
+
+ Args:
+ animals_and_keypoints (tuple): containing predicted keypoints (pred_kpts), ground truth keypoints (gt_kpts)
+ and individual names (indv_names)
+ """
+ pred_kpts, gt_kpts, indv_names = animals_and_keypoints
+
+ col_ind = deeplabcut_torch_match_predictions_gt.rmse_match_prediction_to_gt(
+ pred_kpts, gt_kpts
+ )
+ assert isinstance(col_ind, np.ndarray)
+ assert col_ind.shape == (num_animals,)
+
+
+def test_oks_match_predictions_to_gt(
+ animals_and_keypoints: tuple, num_animals: int = 6
+) -> None:
+ """Summary:
+ Test if oks_match_predictions_to_gt function returns the expected shape output.
+
+ Args:
+ animals_and_keypoints (tuple): containing predicted keypoints (pred_kpts), ground truth keypoints (gt_kpts)
+ and individual names (indv_names)
+ """
+ pred_kpts, gt_kpts, indv_names = animals_and_keypoints
+
+ col_ind = deeplabcut_torch_match_predictions_gt.rmse_match_prediction_to_gt(
+ pred_kpts, gt_kpts
+ )
+ assert isinstance(col_ind, np.ndarray)
+ assert col_ind.shape == (num_animals,)
+
+
+def test_extend_col_ind(animals_and_keypoints: tuple, num_animals: int = 6) -> None:
+ """Summary:
+ Test if the column indices have the expected shape.
+
+ Args:
+ animals_and_keypoints (tuple): containing predicted keypoints (pred_kpts), ground truth keypoints (gt_kpts)
+ and individual names (indv_names)
+ """
+ pred_kpts, gt_kpts, indv_names = animals_and_keypoints
+
+ col_ind = deeplabcut_torch_match_predictions_gt.rmse_match_prediction_to_gt(
+ pred_kpts, gt_kpts
+ )
+ extended_array = deeplabcut_torch_match_predictions_gt.extend_col_ind(
+ col_ind, num_animals
+ )
+ assert extended_array.shape == (num_animals,)
diff --git a/tests/pose_estimation_pytorch/other/test_modelzoo.py b/tests/pose_estimation_pytorch/other/test_modelzoo.py
new file mode 100644
index 0000000000..f4ed80f5c8
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_modelzoo.py
@@ -0,0 +1,49 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import os
+
+import pytest
+
+from deeplabcut.modelzoo.video_inference import video_inference_superanimal
+from deeplabcut.utils import auxiliaryfunctions
+
+examples_folder = os.path.join(
+ auxiliaryfunctions.get_deeplabcut_path(),
+ "modelzoo",
+ "examples",
+)
+
+# requires videos to be in the examples folder
+@pytest.mark.skip
+@pytest.mark.parametrize(
+ "video_paths, superanimal_name",
+ [
+ (f"{examples_folder}/black_dog.mp4", "superanimal_quadruped"),
+ (f"{examples_folder}/black_dog.mp4", "superanimal_quadruped_hrnetw32"),
+ (f"{examples_folder}/swear_mouse_tiny.mp4", "superanimal_topviewmouse"),
+ (
+ f"{examples_folder}/swear_mouse_tiny.mp4",
+ "superanimal_topviewmouse_hrnetw32",
+ ),
+ ],
+)
+def test_video_inference_saves_file(video_paths, superanimal_name):
+ video_inference_superanimal(
+ video_paths,
+ superanimal_name=superanimal_name,
+ )
+ if isinstance(video_paths, str):
+ video_paths = [video_paths]
+ for video_path in video_paths:
+ output_path = video_path.replace(".mp4", f"_labeled.mp4")
+ assert os.path.exists(output_path), "Output video file does not exist"
+
+ assert os.stat(output_path).st_size > 0, "Output video file is empty"
diff --git a/tests/pose_estimation_pytorch/other/test_paf_targets.py b/tests/pose_estimation_pytorch/other/test_paf_targets.py
new file mode 100644
index 0000000000..f01fc3275a
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_paf_targets.py
@@ -0,0 +1,41 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import pytest
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators import pafs_targets
+
+
+@pytest.mark.parametrize(
+ "batch_size, num_keypoints, image_size",
+ [(2, 2, (64, 64)), (1, 5, (48, 64)), (8, 50, (64, 48))],
+)
+def test_paf_target_generation(
+ batch_size: int, num_keypoints: int, image_size: tuple, num_animals=2
+):
+ labels = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_keypoints, 2)
+ )
+ } # 2 for x,y coords
+ graph = [(i, j) for i in range(num_keypoints) for j in range(i + 1, num_keypoints)]
+ prediction = {
+ "heatmap": torch.rand((batch_size, num_keypoints, image_size[0], image_size[1])),
+ "paf": torch.rand((batch_size, len(graph) * 2, image_size[0], image_size[1])),
+ }
+ generator = pafs_targets.PartAffinityFieldGenerator(graph=graph, width=20)
+ targets_output = generator(1, prediction, labels)
+ assert targets_output["paf"]["target"].shape == (
+ batch_size,
+ len(graph) * 2,
+ image_size[0],
+ image_size[1],
+ )
diff --git a/tests/pose_estimation_pytorch/other/test_pose_model.py b/tests/pose_estimation_pytorch/other/test_pose_model.py
new file mode 100644
index 0000000000..977cbbb88e
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_pose_model.py
@@ -0,0 +1,306 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import copy
+import random
+
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.models as dlc_models
+from deeplabcut.pose_estimation_pytorch.models import CRITERIONS, TARGET_GENERATORS, PREDICTORS
+from deeplabcut.pose_estimation_pytorch.models.criterions import LOSS_AGGREGATORS
+from deeplabcut.pose_estimation_pytorch.models.modules import AdaptBlock, BasicBlock
+
+backbones_dicts = [
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w32",
+ "output_channels": 480,
+ "stride": 4,
+ "interpolate_branches": True,
+ },
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w18",
+ "output_channels": 270,
+ "stride": 4,
+ "interpolate_branches": True,
+ },
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w48",
+ "output_channels": 720,
+ "stride": 4,
+ "interpolate_branches": True,
+ },
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w32",
+ "output_channels": 32,
+ "interpolate_branches": False,
+ "increased_channel_count": False,
+ "stride": 4,
+ },
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w18",
+ "output_channels": 18,
+ "interpolate_branches": False,
+ "increased_channel_count": False,
+ "stride": 4,
+ },
+ {
+ "type": "HRNet",
+ "model_name": "hrnet_w48",
+ "output_channels": 48,
+ "interpolate_branches": False,
+ "increased_channel_count": False,
+ "stride": 4,
+ },
+ {"type": "ResNet", "model_name": "resnet50_gn", "output_channels": 2048, "stride": 32},
+]
+
+heads_dicts = [
+ {
+ "type": "HeatmapHead",
+ "predictor": {
+ "type": "HeatmapPredictor",
+ "location_refinement": True,
+ "locref_std": 7.2801,
+ },
+ "target_generator": {
+ "type": "HeatmapPlateauGenerator",
+ "num_heatmaps": "num_bodyparts",
+ "pos_dist_thresh": 17,
+ "heatmap_mode": "KEYPOINT",
+ "generate_locref": True,
+ "locref_std": 7.2801,
+ },
+ "criterion": {
+ "heatmap": {
+ "type": "WeightedBCECriterion",
+ "weight": 1.0,
+ },
+ "locref": {
+ "type": "WeightedHuberCriterion",
+ "weight": 0.05,
+ },
+ },
+ "heatmap_config": {
+ "channels": [2048, 1024, -1],
+ "kernel_size": [2, 2],
+ "strides": [2, 2],
+ },
+ "locref_config": {
+ "channels": [2048, 1024, -1],
+ "kernel_size": [2, 2],
+ "strides": [2, 2],
+ },
+ "output_channels": -1,
+ "input_channels": 2048,
+ "total_stride": 4,
+ },
+ {
+ "type": "TransformerHead",
+ "predictor": {
+ "type": "HeatmapPredictor",
+ "location_refinement": False,
+ },
+ "target_generator": {
+ "type": "HeatmapPlateauGenerator",
+ "num_heatmaps": "num_bodyparts",
+ "pos_dist_thresh": 17,
+ "heatmap_mode": "KEYPOINT",
+ "generate_locref": False,
+ },
+ "criterion": {"type": "WeightedBCECriterion"},
+ "dim": 192,
+ "hidden_heatmap_dim": 384,
+ "heatmap_dim": -1,
+ "apply_multi": True,
+ "heatmap_size": [-1, -1],
+ "apply_init": True,
+ "total_stride": 1,
+ "input_channels": -1,
+ "output_channels": -1,
+ "head_stride": 1,
+ },
+ {
+ "type": "DEKRHead",
+ "predictor": {
+ "type": "DEKRPredictor",
+ "num_animals": 1,
+ "keypoint_score_type": "heatmap",
+ "max_absorb_distance": 75,
+ },
+ "target_generator": {
+ "type": "DEKRGenerator",
+ "num_joints": "num_bodyparts",
+ "pos_dist_thresh": 17,
+ "bg_weight": 0.1,
+ },
+ "criterion": {
+ "heatmap": {
+ "type": "WeightedBCECriterion",
+ "weight": 1.0,
+ },
+ "offset": {
+ "type": "WeightedHuberCriterion",
+ "weight": 0.03,
+ },
+ },
+ "heatmap_config": {
+ "channels": [480, 64, -1],
+ "num_blocks": 1,
+ "dilation_rate": 1,
+ "final_conv_kernel": 1,
+ "block": BasicBlock,
+ },
+ "offset_config": {
+ "channels": [480, -1, -1],
+ "num_offset_per_kpt": 15,
+ "num_blocks": 1,
+ "dilation_rate": 1,
+ "final_conv_kernel": 1,
+ "block": AdaptBlock,
+ },
+ "total_stride": 1,
+ "input_channels": 480,
+ "output_channels": -1,
+ },
+]
+
+
+def _generate_random_backbone_inputs(i):
+ # Returns sizes that are divisible by 64to be able to predict consistently output size
+ # (and be able to do the forward pass of HRNet)
+ x_size_tmp, y_size_tmp = random.randint(100, 1000), random.randint(100, 1000)
+ return (
+ backbones_dicts[i],
+ (x_size_tmp - x_size_tmp % 64, y_size_tmp - y_size_tmp % 64),
+ )
+
+
+@pytest.mark.parametrize(
+ "backbone_dict, input_size",
+ [_generate_random_backbone_inputs(i) for i in range(len(backbones_dicts))],
+)
+def test_backbone(backbone_dict, input_size):
+ input_tensor = torch.Tensor(1, 3, input_size[1], input_size[0])
+
+ stride = backbone_dict.pop("stride")
+ output_channels = backbone_dict.pop("output_channels")
+ backbone = dlc_models.BACKBONES.build(backbone_dict)
+
+ features = backbone(input_tensor)
+ _, c, h, w = features.shape
+ assert c == output_channels
+ assert h == input_size[1] // stride
+ assert w == input_size[0] // stride
+
+
+def _generate_random_head_inputs(i):
+ # Returns sizes that are divisible by 64to be able to predict consistently output size
+ # (and be able to do the forward pass of HRNet)
+ x_size_tmp, y_size_tmp = random.randint(8, 500), random.randint(8, 500)
+ num_kpts = random.randint(2, 50)
+ return (
+ heads_dicts[i],
+ (x_size_tmp - x_size_tmp % 4, y_size_tmp - y_size_tmp % 4),
+ num_kpts,
+ )
+
+
+@pytest.mark.parametrize(
+ "head_dict, input_shape, num_keypoints",
+ [_generate_random_head_inputs(i) for i in range(len(heads_dicts))],
+)
+def test_head(head_dict, input_shape, num_keypoints):
+ w, h = input_shape
+ head_dict = copy.deepcopy(head_dict)
+
+ head_type = head_dict["type"]
+ input_channels = head_dict.pop("input_channels")
+ output_channels = head_dict.pop("output_channels")
+ total_stride = head_dict.pop("total_stride")
+ if head_type == "HeatmapHead":
+ output_channels = num_keypoints
+ head_dict["heatmap_config"]["channels"][2] = output_channels
+ head_dict["locref_config"]["channels"][2] = 2 * output_channels
+ head_dict["target_generator"]["num_heatmaps"] = output_channels
+ input_tensor = torch.zeros((1, input_channels, h, w))
+
+ elif head_type == "TransformerHead":
+ output_channels = num_keypoints
+ input_channels = num_keypoints
+ head_dict["heatmap_dim"] = h * w
+ head_dict["heatmap_size"] = [h, w]
+ head_dict["target_generator"]["num_heatmaps"] = output_channels
+ input_tensor = torch.zeros((1, input_channels, head_dict["dim"] * 3))
+
+ elif head_type == "DEKRHead":
+ output_channels = num_keypoints + 1
+ head_dict["target_generator"]["num_joints"] = num_keypoints
+ head_dict["heatmap_config"]["channels"][2] = num_keypoints + 1
+ head_dict["offset_config"]["channels"][1] = (
+ num_keypoints * head_dict["offset_config"]["num_offset_per_kpt"]
+ )
+ head_dict["offset_config"]["channels"][2] = num_keypoints
+ input_tensor = torch.zeros((1, input_channels, h, w))
+
+ if "type" in head_dict["criterion"]:
+ head_dict["criterion"] = CRITERIONS.build(head_dict["criterion"])
+ else:
+ weights = {}
+ criterions = {}
+ for loss_name, criterion_cfg in head_dict["criterion"].items():
+ weights[loss_name] = criterion_cfg.get("weight", 1.0)
+ criterion_cfg = {
+ k: v for k, v in criterion_cfg.items() if k != "weight"
+ }
+ criterions[loss_name] = CRITERIONS.build(criterion_cfg)
+
+ aggregator_cfg = {"type": "WeightedLossAggregator", "weights": weights}
+ head_dict["aggregator"] = LOSS_AGGREGATORS.build(aggregator_cfg)
+ head_dict["criterion"] = criterions
+
+ head_dict["target_generator"] = TARGET_GENERATORS.build(
+ head_dict["target_generator"]
+ )
+ head_dict["predictor"] = PREDICTORS.build(head_dict["predictor"])
+ head = dlc_models.HEADS.build(head_dict)
+
+ output = head(input_tensor)["heatmap"]
+ _, c_out, h_out, w_out = output.shape
+ assert (h_out == h * total_stride) and (w_out == w * total_stride)
+ assert c_out == output_channels
+
+
+def test_msa_hrnet():
+ # TODO: build microsoft asia hrnet and check dimension of output
+ # TODO: check if hyperparameters are loaded correctly (from the config file)
+ pass
+
+
+def test_msa_tokenpose():
+ # TODO: build microsoft asia hrnet and check dimension of output
+ # TODO: check if hyperparameters are loaded correctly (from the config file)
+ # cf https://github.com/amathislab/BUCTDdev/blob/main/lib/models/transpose_h.py#L1
+ pass
+
+
+def test_msa_hrnetCOAM():
+ # TODO: build BUCTD COAM hrnet and check dimension of output
+ # TODO: check if hyperparameters are loaded correctly (from the config file)
+ pass
+
+
+# TODO: add other model variants our pipeline can build ;)
diff --git a/tests/pose_estimation_pytorch/other/test_seq_targets.py b/tests/pose_estimation_pytorch/other/test_seq_targets.py
new file mode 100644
index 0000000000..82f931f520
--- /dev/null
+++ b/tests/pose_estimation_pytorch/other/test_seq_targets.py
@@ -0,0 +1,55 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from itertools import combinations
+
+import torch
+
+from deeplabcut.pose_estimation_pytorch.models.target_generators import (
+ TARGET_GENERATORS,
+)
+
+
+def test_sequential_generator():
+ batch_size = 4
+ image_size = 256, 256
+ num_keypoints = 12
+ num_animals = 2
+ graph = [list(edge) for edge in combinations(range(num_keypoints), 2)]
+ num_limbs = len(graph)
+ cfg = {
+ "type": "SequentialGenerator",
+ "generators": [
+ {
+ "type": "HeatmapPlateauGenerator",
+ "num_heatmaps": num_keypoints,
+ "pos_dist_thresh": 17,
+ "generate_locref": True,
+ "locref_std": 7.2801,
+ },
+ {"type": "PartAffinityFieldGenerator", "graph": graph, "width": 20},
+ ],
+ }
+ gen = TARGET_GENERATORS.build(cfg)
+
+ annotations = {
+ "keypoints": torch.randint(
+ 1, min(image_size), (batch_size, num_animals, num_keypoints, 2)
+ )
+ }
+ head_outputs = {
+ "heatmap": torch.rand(batch_size, num_keypoints, 32, 32),
+ "locref": torch.rand(batch_size, num_keypoints * 2, 32, 32),
+ "paf": torch.rand(batch_size, num_limbs * 2, 32, 32),
+ }
+ out = gen(stride=1, outputs=head_outputs, labels=annotations)
+ assert all(s in out for s in list(head_outputs))
+ for k, v in head_outputs.items():
+ assert out[k]["target"].shape == v.shape
diff --git a/tests/pose_estimation_pytorch/post_processing/test_identity.py b/tests/pose_estimation_pytorch/post_processing/test_identity.py
new file mode 100644
index 0000000000..42ae454341
--- /dev/null
+++ b/tests/pose_estimation_pytorch/post_processing/test_identity.py
@@ -0,0 +1,58 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+""" Tests identity matching """
+import numpy as np
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.post_processing.identity import assign_identity
+
+
+@pytest.mark.parametrize(
+ "prediction, identity_scores, output_order",
+ [
+ (
+ [
+ [[0, 0, 1.0], [0, 0, 1.0]], # assembly 1
+ [[5, 5, 1.0], [5, 5, 1.0]], # assembly 2
+ [[9, 9, 1.0], [9, 9, 1.0]], # assembly 3
+ ],
+ [ # a0 -> idv1, a1 -> idv2, a2 -> idv0
+ [[0.1, 0.8, 0.3], [0.1, 0.7, 0.3]], # assembly 1 ID scores
+ [[0.2, 0.1, 0.6], [0.3, 0.1, 0.5]], # assembly 2 ID scores
+ [[0.7, 0.1, 0.1], [0.6, 0.2, 0.2]], # assembly 3 ID scores
+ ],
+ [2, 0, 1],
+ ),
+ (
+ [
+ [[0, 0, 1.0], [0, 0, 1.0]], # assembly 1
+ [[1, 1, 1.0], [5, 5, 1.0]], # assembly 2
+ [[0, 0, 1.0], [9, 9, 1.0]], # assembly 3
+ ],
+ [ # a0 -> idv0, a1 -> idv1, a2 -> idv2
+ [[0.4, 0.4, 0.3], [0.5, 0.3, 0.3]], # assembly 1 ID scores
+ [[0.4, 0.4, 0.3], [0.3, 0.5, 0.4]], # assembly 2 ID scores
+ [[0.2, 0.2, 0.4], [0.2, 0.2, 0.3]], # assembly 3 ID scores
+ ],
+ [0, 1, 2],
+ ),
+ ],
+)
+def test_single_identity_assignment(prediction, identity_scores, output_order):
+ predictions = np.array(prediction)
+ identity_scores = np.array(identity_scores)
+ new_order = assign_identity(predictions, identity_scores)
+ predictions_with_id = predictions[new_order]
+
+ print()
+ print(predictions.shape)
+ print(identity_scores.shape)
+ np.testing.assert_equal(predictions[output_order], predictions_with_id)
diff --git a/tests/pose_estimation_pytorch/runners/bottum_up.py b/tests/pose_estimation_pytorch/runners/bottum_up.py
new file mode 100644
index 0000000000..821c11422d
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/bottum_up.py
@@ -0,0 +1,104 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+""" Tests for the bottom-up pytorch runner """
+from pathlib import Path
+from typing import Dict, Any
+
+import pytest
+import torch
+from deeplabcut.pose_estimation_pytorch.config import make_pytorch_pose_config
+
+from deeplabcut.pose_estimation_pytorch.models import PoseModel, LOSSES, PREDICTORS
+from deeplabcut.pose_estimation_pytorch.models.criterion import WeightedAggregateLoss
+from deeplabcut.pose_estimation_pytorch.runners import RUNNERS
+from deeplabcut.pose_estimation_pytorch.runners.schedulers import LRListScheduler
+from deeplabcut.utils import auxiliaryfunctions
+
+
+SINGLE_ANIMAL_NETS = ["resnet_50"]
+MULTI_ANIMAL_NETS = ["dekr_w18"]
+NETS = [(n, False) for n in SINGLE_ANIMAL_NETS] + [(n, True) for n in MULTI_ANIMAL_NETS]
+
+
+def print_dict(data: Dict, indent: int = 0):
+ for k, v in data.items():
+ if isinstance(v, dict):
+ print_dict(v, indent=indent + 2)
+ else:
+ print(f"{indent * ' '}{k}: {v}")
+
+
+@pytest.mark.parametrize("net_type, multianimal", NETS)
+def test_build_bottom_up_runner(
+ net_type: str,
+ multianimal: bool,
+) -> None:
+ project_cfg: Dict[str, Any] = {"multianimalproject": multianimal}
+ if multianimal:
+ project_cfg["bodyparts"] = "MULTI!"
+ project_cfg["multianimalbodyparts"] = ["head", "shoulder", "knee", "toe"]
+ project_cfg["uniquebodyparts"] = []
+ project_cfg["individuals"] = ["tom", "jerry"]
+ else:
+ project_cfg["bodyparts"] = ["head", "shoulder", "knee", "toe"]
+ project_cfg["uniquebodyparts"] = []
+ project_cfg["individuals"] = ["tom"]
+
+ root_path = Path(auxiliaryfunctions.get_deeplabcut_path())
+ template_path = root_path / "pose_estimation_pytorch" / "apis" / "pytorch_config.yaml"
+ template = auxiliaryfunctions.read_plainconfig(str(template_path))
+ pytorch_cfg = make_pytorch_pose_config(project_cfg, str(template_path), net_type)
+ print_dict(pytorch_cfg)
+
+ pose_model = PoseModel.build(pytorch_cfg["model"])
+
+ head_criterions = []
+ for head_cfg in pytorch_cfg["model"]["heads"]:
+ crit_cfg = head_cfg["criterion"]
+ criterion_weight = crit_cfg.get("weight", 1)
+ criterion = LOSSES.build({k: v for k, v in crit_cfg.items() if k != "weight"})
+ head_criterions.append((criterion_weight, criterion))
+ criterion = WeightedAggregateLoss(head_criterions)
+
+ get_optimizer = getattr(torch.optim, pytorch_cfg["optimizer"]["type"])
+ optimizer = get_optimizer(
+ params=pose_model.parameters(), **pytorch_cfg["optimizer"]["params"]
+ )
+
+ predictor = PREDICTORS.build(dict(pytorch_cfg["model"]["predictor"]))
+
+ if pytorch_cfg.get("scheduler"):
+ if pytorch_cfg["scheduler"]["type"] == "LRListScheduler":
+ _scheduler = LRListScheduler
+ else:
+ _scheduler = getattr(
+ torch.optim.lr_scheduler, pytorch_cfg["scheduler"]["type"]
+ )
+ scheduler = _scheduler(
+ optimizer=optimizer, **pytorch_cfg["scheduler"]["params"]
+ )
+ else:
+ scheduler = None
+
+ logger = None
+ runner = RUNNERS.build(
+ dict(
+ **pytorch_cfg["solver"],
+ model=pose_model,
+ criterion=criterion,
+ optimizer=optimizer,
+ predictor=predictor,
+ cfg=pytorch_cfg,
+ device=pytorch_cfg["device"],
+ scheduler=scheduler,
+ logger=logger,
+ )
+ )
diff --git a/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py b/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py
new file mode 100644
index 0000000000..04b5b9a008
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_dynamic_cropper.py
@@ -0,0 +1,145 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests the dynamic cropper"""
+import pytest
+
+import numpy as np
+import torch
+
+from deeplabcut.pose_estimation_pytorch.runners.dynamic_cropping import DynamicCropper
+
+
+@pytest.mark.parametrize("dynamic", [(False, 0.5, 10)])
+def test_build_dynamic_cropper(dynamic: tuple[bool, float, int]):
+ cropper = DynamicCropper.build(*dynamic)
+ should_be_built, threshold, margin = dynamic
+ if should_be_built:
+ assert isinstance(cropper, DynamicCropper)
+ assert cropper.threshold == threshold
+ assert cropper.margin == margin
+ else:
+ assert cropper is None
+
+
+@pytest.mark.parametrize("batch_size", [0, 2, 8])
+def test_dynamic_fails_with_image_batch(batch_size: int):
+ cropper = DynamicCropper(threshold=0.6, margin=10)
+ with pytest.raises(RuntimeError):
+ cropper.crop(torch.zeros(batch_size, 3, 128, 128))
+
+
+def test_dynamic_fails_with_variable_frame_size():
+ cropper = DynamicCropper(threshold=0.6, margin=10)
+ cropper.crop(torch.zeros(1, 3, 64, 64))
+ with pytest.raises(RuntimeError):
+ cropper.crop(torch.zeros(1, 3, 128, 128))
+
+
+def test_dynamic_fails_with_update_before_crop():
+ cropper = DynamicCropper(threshold=0.6, margin=10)
+ with pytest.raises(RuntimeError):
+ cropper.update(torch.ones(5, 17, 3))
+
+
+@pytest.mark.parametrize("threshold", [0.25, 0.5, 0.8])
+def test_dynamic_cropper_does_nothing_with_low_quality(threshold: float):
+ cropper = DynamicCropper(threshold=threshold, margin=10)
+ image_in = torch.ones((1, 3, 32, 32))
+ cropper.crop(image_in)
+ for i in range(10):
+ pose = _generate_random_pose(
+ (32, 64),
+ min_score=0.0,
+ max_score=threshold - 0.001,
+ seed=i,
+ )
+ cropper.update(pose)
+ image_out = cropper.crop(image_in)
+ assert torch.equal(image_in, image_out)
+
+
+@pytest.mark.parametrize(
+ "pose, threshold, margin, expected_crop",
+ [
+ ([[float("nan"), float("nan"), float("nan")]], 0.1, 10, [0, 0, 64, 64]),
+ ([[float("nan"), 30, 0.0]], 0.5, 10, [0, 0, 64, 64]),
+ ([[20, 30, 0.0]], 0.5, 10, [0, 0, 64, 64]),
+ ([[20, 30, 0.49]], 0.5, 10, [0, 0, 64, 64]),
+ ([[20, 30, 0.8]], 0.5, 10, [10, 20, 30, 40]),
+ ([[20, 30, 0.8], [float("nan"), float("nan"), 0.2]], 0.5, 15, [5, 15, 35, 45]),
+ ([[20, 30, 0.8], [5, 5, 0.2]], 0.5, 15, [0, 0, 35, 45]),
+ ([[20, 30, 0.8], [35, 30, 0.79]], 0.8, 5, [15, 25, 40, 35]),
+ ([[40, 10, 0.2], [35, 15, 0.79]], 0.3, 8, [27, 2, 48, 23]),
+ (
+ [
+ [[float("nan"), float("nan"), float("nan")]],
+ [[float("nan"), float("nan"), float("nan")]],
+ ],
+ 0.15, 10, [0, 0, 64, 64]
+ ),
+ (
+ [
+ [[20, 30, 0.8], [5, 12, 0.2]],
+ [[40, 10, 0.2], [35, 15, 0.79]],
+ ],
+ 0.15, 5, [0, 5, 45, 35]
+ ),
+ ],
+)
+def test_dynamic_cropper_basic_crop(
+ pose: list[list[float]],
+ threshold: float,
+ margin: int,
+ expected_crop: tuple[int, int, int, int]
+) -> None:
+ x0, y0, x1, y1 = expected_crop
+ crop_w, crop_h = x1 - x0, y1 - y0
+
+ image_in = torch.zeros((1, 3, 64, 64))
+ image_in[:, :, y0:y1, x0:x1] = 1
+ expected_image_out = torch.ones((1, 3, crop_h, crop_w))
+
+ cropper = DynamicCropper(threshold=threshold, margin=margin)
+ image_out = cropper.crop(image_in)
+ assert torch.equal(image_out, image_in)
+
+ cropper.update(torch.tensor(pose))
+ image_out = cropper.crop(image_in)
+ assert image_out.shape == expected_image_out.shape
+ assert torch.equal(image_out, expected_image_out)
+
+ pose_out = torch.tensor(pose)
+ print("\nPose in")
+ print(pose_out.numpy())
+ pose_out[..., 0] -= x0
+ pose_out[..., 1] -= y0
+ print("Pose out before update")
+ print(pose_out.numpy())
+ cropper.update(pose_out)
+ print("Pose out after update")
+ print(pose_out.numpy())
+ np.testing.assert_allclose(pose_out.numpy(), np.array(pose))
+
+
+def _generate_random_pose(
+ image_shape: tuple[int, int],
+ min_score: float,
+ max_score: float,
+ num_animals: int = 3,
+ num_keypoints: int = 7,
+ seed: int = 0,
+) -> torch.Tensor:
+ gen = np.random.default_rng(seed)
+ pose = gen.random((num_animals, num_keypoints, 3))
+ pose[..., 0] *= image_shape[0]
+ pose[..., 1] *= image_shape[1]
+ pose[..., 2] = (pose[..., 2] * (max_score - min_score)) + min_score
+ return torch.from_numpy(pose)
diff --git a/tests/pose_estimation_pytorch/runners/test_logger.py b/tests/pose_estimation_pytorch/runners/test_logger.py
new file mode 100644
index 0000000000..6f40615e07
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_logger.py
@@ -0,0 +1,75 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests loggers"""
+from typing import Any
+
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.runners.logger as logging
+
+
+class MockImageLogger(logging.ImageLoggerMixin):
+ """Mock image logger"""
+
+ def log_images(
+ self,
+ inputs: dict[str, Any],
+ outputs: dict[str, torch.Tensor],
+ targets: dict[str, dict[str, torch.Tensor]],
+ step: int,
+ ) -> None:
+ pass
+
+
+@pytest.mark.parametrize(
+ "keypoints",
+ [
+ [
+ [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],
+ ],
+ [
+ [[float("nan"), float("nan")], [float("nan"), float("nan")]],
+ ],
+ [
+ [[0.0, 0.0], [1, 1], [2, 2]],
+ ],
+ [[[float("nan"), 0.0], [1, 1], [2, 2]]],
+ [[[-1.0, -1.0], [1, 1], [2, 2]]],
+ [
+ [[-1.0, -1.0], [-1.0, -1.0]],
+ ],
+ [
+ [[-1.0, -1.0], [-1.0, -1.0]],
+ [[1.0, 1.0], [1.0, 1.0]],
+ ],
+ ],
+)
+@pytest.mark.parametrize("denormalize", [True, False])
+def test_prepare_image(keypoints: list[list[float]], denormalize: bool) -> None:
+ image = torch.ones((3, 256, 256))
+ keypoints = torch.tensor(keypoints)
+
+ print()
+ print(f"IMAGE: {image.shape}")
+ print(f"KEYPOINTS: {keypoints.shape}")
+ for k in keypoints:
+ print(k)
+ print()
+ print()
+
+ logger = MockImageLogger()
+ logger._prepare_image(
+ image=image,
+ denormalize=denormalize,
+ keypoints=keypoints,
+ bboxes=None,
+ )
diff --git a/tests/pose_estimation_pytorch/runners/test_runners.py b/tests/pose_estimation_pytorch/runners/test_runners.py
new file mode 100644
index 0000000000..3f2fc2e3da
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_runners.py
@@ -0,0 +1,40 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import pickle
+from pathlib import Path
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.runners as runners
+
+
+@pytest.mark.parametrize("value", [True, False])
+def test_set_load_weights_only(value: bool):
+ print(f"\nget_load_weights_only: {runners.get_load_weights_only()}")
+ print(f"setting value to {value}")
+ runners.set_load_weights_only(value)
+ print(f"get_load_weights_only: {runners.get_load_weights_only()}\n")
+ assert runners.get_load_weights_only() == value
+
+
+def test_load_snapshot_weights_only_error(tmpdir_factory):
+ snapshot_dir = Path(tmpdir_factory.mktemp("snapshot-dir"))
+ snapshot_path = snapshot_dir / "snapshot.pt"
+ torch.save(dict(content=np.zeros(10)), str(snapshot_path))
+
+ runners.set_load_weights_only(False)
+ with pytest.raises(pickle.UnpicklingError):
+ runners.Runner.load_snapshot(
+ snapshot_path, device="cpu", model=Mock(), weights_only=True
+ )
diff --git a/tests/pose_estimation_pytorch/runners/test_runners_inference.py b/tests/pose_estimation_pytorch/runners/test_runners_inference.py
new file mode 100644
index 0000000000..272b5eb509
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_runners_inference.py
@@ -0,0 +1,190 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests inference runners"""
+from unittest.mock import Mock, patch
+
+import numpy as np
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.data.postprocessor as post
+import deeplabcut.pose_estimation_pytorch.data.preprocessor as prep
+import deeplabcut.pose_estimation_pytorch.runners.inference as inference
+from deeplabcut.pose_estimation_pytorch import get_load_weights_only
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.build_optimizer", Mock())
+@pytest.mark.parametrize("task", [Task.DETECT, Task.TOP_DOWN, Task.BOTTOM_UP])
+@pytest.mark.parametrize("weights_only", [None, True, False])
+def test_load_weights_only_with_build_training_runner(task: Task, weights_only: bool):
+ with patch("deeplabcut.pose_estimation_pytorch.runners.base.torch.load") as load:
+ snapshot = "snapshot.pt"
+ runner = inference.build_inference_runner(
+ task=task,
+ model=Mock(),
+ device="cpu",
+ snapshot_path=snapshot,
+ load_weights_only=weights_only,
+ )
+ if weights_only is None:
+ weights_only = get_load_weights_only()
+ load.assert_called_once_with(
+ snapshot, map_location="cpu", weights_only=weights_only
+ )
+
+
+class MockInferenceRunner(inference.InferenceRunner):
+ """Mocks the predict function for an inference runner"""
+
+ def __init__(
+ self,
+ batch_size: int = 1,
+ preprocessor: prep.Preprocessor | None = None,
+ postprocessor: post.Postprocessor | None = None,
+ ) -> None:
+ super().__init__(
+ model=Mock(),
+ batch_size=batch_size,
+ preprocessor=preprocessor,
+ postprocessor=postprocessor,
+ )
+ self.batch_shapes = []
+
+ def predict(self, inputs: torch.Tensor) -> list[dict[str, dict[str, np.ndarray]]]:
+ self.batch_shapes.append(tuple(inputs.shape))
+ return [ # return first elem of input
+ {"mock": {"index": i[0, 0, 0].detach().numpy()}} for i in inputs
+ ]
+
+
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
+def test_mock_bottom_up(batch_size):
+ h, w = 640, 480
+ images = [i * np.ones((1, 3, h, w)) for i in range(10)]
+
+ runner = MockInferenceRunner(batch_size=batch_size)
+ predictions = runner.inference(images)
+
+ print()
+ print(f"Num images: {len(predictions)}")
+ print(f"Num predictions: {len(predictions)}")
+ print(f"Batch shapes: {runner.batch_shapes}")
+ print(80 * "-")
+ for i in images:
+ print(i[0, 0, 0, 0])
+ print("----")
+ print(80 * "-")
+ for p in predictions:
+ print(p)
+ print("----")
+
+ _check_batch_shapes(batch_size, h, w, runner.batch_shapes)
+ assert len(images) == len(predictions)
+ for i, p in zip(images, predictions):
+ assert len(p) == 1 # only 1 output per image
+ assert i[0, 0, 0, 0] == p[0]["mock"]["index"]
+
+
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8])
+@pytest.mark.parametrize(
+ "detections_per_image",
+ [
+ [1, 1, 1, 1, 1],
+ [0, 1, 0, 1, 1], # some frames might not have predictions
+ [0, 0, 0, 5, 2],
+ [1, 2, 3, 4],
+ [3, 4, 2, 1, 4],
+ [4, 23, 5, 20, 64, 100],
+ ],
+)
+def test_mock_top_down(batch_size, detections_per_image):
+ h, w = 8, 8
+ images = []
+ for index, num_detections in enumerate(detections_per_image):
+ if num_detections == 0:
+ detections = np.zeros((0, 3, 1, 1)) # random shape when no detections
+ else:
+ detections = np.concatenate(
+ [
+ (1_000_000 * (index + 1) + i) * np.ones((1, 3, h, w))
+ for i in range(num_detections)
+ ],
+ axis=0,
+ )
+
+ images.append(detections)
+
+ runner = MockInferenceRunner(batch_size=batch_size)
+ predictions = runner.inference(images)
+
+ print()
+ print(f"Num images: {len(predictions)}")
+ print(f"Num predictions: {len(predictions)}")
+ print(80 * "-")
+ for i in images:
+ for i_det in i:
+ print(i_det.shape)
+ print(i_det[0, 0, 0])
+ print("----")
+
+ print(80 * "-")
+ for p in predictions:
+ print(p)
+ print("----")
+
+ _check_batch_shapes(batch_size, h, w, runner.batch_shapes)
+
+ assert len(images) == len(predictions)
+ for i, p in zip(images, predictions):
+ assert len(p) == len(i) # one prediction per input
+ for i_det, p_det in zip(i, p):
+ print(i_det.shape)
+ print(p_det["mock"]["index"])
+ assert i_det[0, 0, 0] == p_det["mock"]["index"]
+
+
+def test_dynamic_pose_inference_calls_dynamic():
+ pose_batch = [Mock()]
+ image_crop = Mock()
+ image_crop.__len__ = Mock(return_value=1)
+
+ model = Mock()
+ model.get_predictions = Mock()
+ model.get_predictions.return_value = dict(bodypart=dict(poses=pose_batch))
+
+ dynamic = Mock()
+ dynamic.crop = Mock()
+ dynamic.crop.return_value = image_crop
+ dynamic.update = Mock()
+
+ runner = inference.PoseInferenceRunner(
+ model=model,
+ dynamic=dynamic,
+ batch_size=1,
+ )
+ image = torch.Tensor((1, 3, 64, 64))
+ _ = runner.predict(image)
+ dynamic.crop.assert_called_once_with(image)
+ dynamic.update.assert_called_once_with(pose_batch)
+
+
+def _check_batch_shapes(batch_size, h, w, batch_shapes) -> None:
+ for b in batch_shapes[:-1]:
+ assert b[0] == batch_size
+ assert b[1] == 3
+ assert b[2] == h
+ assert b[3] == w
+
+ assert batch_shapes[-1][0] <= batch_size
+ assert batch_shapes[-1][1] <= 3
+ assert batch_shapes[-1][2] <= h
+ assert batch_shapes[-1][3] <= w
diff --git a/tests/pose_estimation_pytorch/runners/test_runners_train.py b/tests/pose_estimation_pytorch/runners/test_runners_train.py
new file mode 100644
index 0000000000..0fff6e671d
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_runners_train.py
@@ -0,0 +1,327 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+from dataclasses import dataclass
+from unittest.mock import Mock, patch
+
+import numpy as np
+import pytest
+import torch
+
+import deeplabcut.pose_estimation_pytorch.runners.schedulers as schedulers
+import deeplabcut.pose_estimation_pytorch.runners.train as train_runners
+from deeplabcut.pose_estimation_pytorch.models import PoseModel
+from deeplabcut.pose_estimation_pytorch.models.backbones import ResNet
+from deeplabcut.pose_estimation_pytorch.models.heads import HeatmapHead
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.build_optimizer", Mock())
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.CSVLogger", Mock())
+@pytest.mark.parametrize("task", [Task.DETECT, Task.TOP_DOWN, Task.BOTTOM_UP])
+@pytest.mark.parametrize("weights_only", [True, False])
+def test_load_weights_only_with_build_training_runner(task: Task, weights_only: bool):
+ runner_config = dict(
+ optimizer=dict(),
+ snapshots=dict(max_snapshots=1, save_epochs=5, save_optimizer_state=False),
+ load_weights_only=weights_only,
+ )
+ with patch("deeplabcut.pose_estimation_pytorch.runners.base.torch.load") as load:
+ train_runners.build_training_runner(
+ runner_config=runner_config,
+ model_folder=Mock(),
+ task=task,
+ model=Mock(),
+ device="cpu",
+ snapshot_path="snapshot.pt",
+ )
+ load.assert_called_once_with(
+ "snapshot.pt", map_location="cpu", weights_only=weights_only
+ )
+
+
+@dataclass
+class SchedulerTestConfig:
+ cfg: dict
+ init_lr: float
+ expected_lrs: list[float]
+
+
+TEST_SCHEDULERS = [
+ SchedulerTestConfig(
+ cfg=dict(
+ type="LRListScheduler",
+ params=dict(milestones=[2, 5], lr_list=[[0.5], [0.1]]),
+ ),
+ init_lr=1.0,
+ expected_lrs=[1.0, 1.0, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="LRListScheduler", params=dict(milestones=[1], lr_list=[[0.1]])),
+ init_lr=0.1,
+ expected_lrs=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="LRListScheduler", params=dict(milestones=[1], lr_list=[[0.5]])),
+ init_lr=0.1,
+ expected_lrs=[0.1, 0.5, 0.5, 0.5],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="StepLR", params=dict(step_size=3, gamma=0.1)),
+ init_lr=1.0,
+ expected_lrs=[1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.01, 0.01, 0.01, 0.001],
+ ),
+]
+
+
+@pytest.mark.parametrize("load_head_weights", [True, False])
+def test_load_head_weights(tmp_path_factory, load_head_weights):
+ model_folder = tmp_path_factory.mktemp("model_folder")
+ runner_config = dict(
+ optimizer=dict(type="SGD", params=dict(lr=1)),
+ snapshots=dict(max_snapshots=1, save_epochs=1, save_optimizer_state=False),
+ )
+
+ model = PoseModel(
+ cfg=dict(),
+ backbone=ResNet(),
+ heads=dict(
+ bodyparts=HeatmapHead(
+ predictor=Mock(),
+ target_generator=Mock(),
+ criterion=Mock(),
+ aggregator=None,
+ heatmap_config=dict(channels=[2048, 10], kernel_size=[3], strides=[2]),
+ ),
+ ),
+ )
+
+ original_state_dict = model.state_dict()
+ zero_state_dict = {
+ k: torch.zeros_like(v) for k, v in original_state_dict.items()
+ }
+
+ load = Mock()
+ load.return_value = dict(model=zero_state_dict)
+
+ with patch("deeplabcut.pose_estimation_pytorch.runners.train.torch.load", load):
+ r = train_runners.build_training_runner(
+ runner_config,
+ model_folder=model_folder,
+ task=Task.BOTTOM_UP,
+ model=model,
+ device="cpu",
+ snapshot_path=model_folder / "snapshot.pt",
+ load_head_weights=load_head_weights,
+ )
+ loaded_state_dict = r.model.state_dict()
+ for k, v in loaded_state_dict.items():
+ if load_head_weights or k.startswith("backbone."):
+ assert torch.equal(v, zero_state_dict[k])
+ else:
+ assert torch.equal(v, original_state_dict[k])
+
+
+@pytest.mark.parametrize("load_head_weights", [True, False])
+def test_mocked_load_head_weights(tmp_path_factory, load_head_weights):
+ model_folder = tmp_path_factory.mktemp("model_folder")
+ snapshot_manager = Mock()
+ snapshot_manager.model_folder = model_folder
+
+ model = Mock()
+ model.backbone = Mock()
+ state_dict = {"backbone.test": 0, "head.test": 1}
+ state_dict_backbone = {"test": 0}
+ load = Mock()
+ load.return_value = dict(model=state_dict)
+
+ with patch("deeplabcut.pose_estimation_pytorch.runners.train.torch.load", load):
+ _ = train_runners.PoseTrainingRunner(
+ model=model,
+ optimizer=Mock(),
+ snapshot_manager=snapshot_manager,
+ device="cpu",
+ snapshot_path="snapshot.pt",
+ load_head_weights=load_head_weights,
+ )
+ if load_head_weights:
+ model.load_state_dict.assert_called_once_with(state_dict)
+ else:
+ model.backbone.load_state_dict.assert_called_once_with(state_dict_backbone)
+
+
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.CSVLogger", Mock())
+@pytest.mark.parametrize(
+ "runner_cls",
+ [
+ train_runners.PoseTrainingRunner,
+ train_runners.DetectorTrainingRunner,
+ ],
+)
+@pytest.mark.parametrize("test_cfg", TEST_SCHEDULERS)
+def test_training_with_scheduler(runner_cls, test_cfg: SchedulerTestConfig) -> None:
+ runner = _fit_runner_and_check_lrs(
+ runner_cls,
+ test_cfg.init_lr,
+ test_cfg.cfg,
+ test_cfg.expected_lrs,
+ )
+ assert runner.current_epoch == len(test_cfg.expected_lrs)
+
+
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.CSVLogger", Mock())
+@pytest.mark.parametrize(
+ "runner_cls",
+ [
+ train_runners.PoseTrainingRunner,
+ train_runners.DetectorTrainingRunner,
+ ],
+)
+@pytest.mark.parametrize("test_cfg", TEST_SCHEDULERS)
+def test_resuming_training_scheduler_every_epoch(
+ runner_cls,
+ test_cfg: SchedulerTestConfig,
+):
+ snapshot_to_load = None
+ for epoch, expected_lr in enumerate(test_cfg.expected_lrs):
+ runner = _fit_runner_and_check_lrs(
+ runner_cls,
+ test_cfg.init_lr,
+ test_cfg.cfg,
+ [expected_lr], # trains for 1 epoch
+ snapshot_to_load=snapshot_to_load,
+ )
+ snapshot_to_load = dict(
+ metadata=dict(epoch=epoch + 1), scheduler=runner.scheduler.state_dict()
+ )
+
+
+@patch("deeplabcut.pose_estimation_pytorch.runners.train.CSVLogger", Mock())
+@pytest.mark.parametrize(
+ "runner_cls",
+ [
+ train_runners.PoseTrainingRunner,
+ train_runners.DetectorTrainingRunner,
+ ],
+)
+@pytest.mark.parametrize(
+ "test_cfg, resume_epoch",
+ [
+ (
+ SchedulerTestConfig(
+ cfg=dict(
+ type="LRListScheduler",
+ params=dict(milestones=[2, 5], lr_list=[[0.5], [0.1]]),
+ ),
+ init_lr=1.0,
+ expected_lrs=[1.0, 1.0, 0.5, 1.0, 1.0, 0.1, 0.1, 0.1],
+ ),
+ 3, # cut after the 3rd epoch - restart at LR=1 until epoch 5
+ ),
+ (
+ SchedulerTestConfig(
+ cfg=dict(type="StepLR", params=dict(step_size=4, gamma=0.1)),
+ init_lr=1.0,
+ expected_lrs=(4 * [1.0]) + (4 * [0.1]) + (4 * [0.01]) + (4 * [0.001]),
+ ),
+ 3, # cut after the 3rd epoch - restart at LR=1 and update at 4 correctly
+ ),
+ (
+ SchedulerTestConfig(
+ cfg=dict(type="StepLR", params=dict(step_size=4, gamma=0.1)),
+ init_lr=1.0,
+ expected_lrs=(4 * [1.0]) + [0.1, 1, 1, 1] + (4 * [0.1]),
+ ),
+ 5, # cut after the 5th epoch - restart at LR=1 and update again at 8
+ ),
+ ],
+)
+def test_resuming_training_with_no_scheduler_state(
+ runner_cls, test_cfg: SchedulerTestConfig, resume_epoch: int
+):
+ """
+ Without a scheduler config, there is no way to set the initial LR. All we can do is
+ set the last_epoch value, and adjust correctly at milestones going forward.
+ """
+ runner = _fit_runner_and_check_lrs(
+ runner_cls,
+ test_cfg.init_lr,
+ test_cfg.cfg,
+ test_cfg.expected_lrs[:resume_epoch],
+ )
+ assert runner.current_epoch == resume_epoch
+
+ runner = _fit_runner_and_check_lrs(
+ runner_cls,
+ test_cfg.init_lr,
+ test_cfg.cfg,
+ expected_lrs=test_cfg.expected_lrs[resume_epoch:],
+ snapshot_to_load=dict(metadata=dict(epoch=resume_epoch)),
+ )
+ assert runner.current_epoch == len(test_cfg.expected_lrs)
+
+
+def _fit_runner_and_check_lrs(
+ runner_cls,
+ init_lr: float,
+ scheduler_cfg: dict,
+ expected_lrs: list[float],
+ snapshot_to_load: dict | None = None,
+) -> train_runners.TrainingRunner:
+ runner_kwargs = dict(device="cpu", eval_interval=1_000_000)
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=init_lr)
+ scheduler = schedulers.build_scheduler(scheduler_cfg, optimizer)
+ num_epochs = len(expected_lrs)
+
+ base_path = "deeplabcut.pose_estimation_pytorch.runners"
+ with patch(f"{base_path}.base.Runner.load_snapshot") as base_mock_load:
+ with patch(f"{base_path}.train.PoseTrainingRunner.load_snapshot") as mock_load:
+ snapshot_path = None
+ base_mock_load.return_value = dict()
+ mock_load.return_value = dict()
+ if snapshot_to_load is not None:
+ snapshot_path = "fake_snapshot.pt"
+ base_mock_load.return_value = snapshot_to_load
+ mock_load.return_value = snapshot_to_load
+
+ print()
+ print(f"Scheduler: {scheduler}")
+ print(f"Starting training for {num_epochs} epochs")
+ runner = runner_cls(
+ model=Mock(),
+ optimizer=optimizer,
+ snapshot_manager=Mock(),
+ scheduler=scheduler,
+ snapshot_path=snapshot_path,
+ **runner_kwargs,
+ )
+
+ # Mock the step call; check that the learning rate is correct for the epoch
+ def step(*args, **kwargs):
+ # the current_epoch value is indexed at 1
+ total_epoch = runner.current_epoch - 1
+ epoch = total_epoch - runner.starting_epoch
+ _assert_learning_rates_match(total_epoch, optimizer, expected_lrs[epoch])
+ optimizer.step()
+ return dict(total_loss=0)
+
+ train_loader, val_loader = [Mock()], [Mock()]
+ runner.step = step
+ runner.fit(train_loader, val_loader, epochs=num_epochs, display_iters=1000)
+
+ return runner
+
+
+def _assert_learning_rates_match(e, optimizer, expected):
+ current_lrs = [g["lr"] for g in optimizer.param_groups]
+ print(f"Epoch {e}: LR={current_lrs}, expected={expected}")
+ for lr in current_lrs:
+ assert isinstance(lr, float)
+ np.testing.assert_almost_equal(lr, expected)
diff --git a/tests/pose_estimation_pytorch/runners/test_schedulers.py b/tests/pose_estimation_pytorch/runners/test_schedulers.py
new file mode 100644
index 0000000000..d3b17fa16c
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_schedulers.py
@@ -0,0 +1,268 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+"""Tests building schedulers from config"""
+import random
+from dataclasses import dataclass
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+import deeplabcut.pose_estimation_pytorch.runners.schedulers as schedulers
+
+
+def generate_random_lr_list(num_floats: int):
+ """Generate list of lists including random numbers.
+
+ Args:
+ num_floats: number of floats we want to include in our list
+
+ Returns:
+ ran_list: random list of sorted numbers, being first number bigger than the last
+ """
+ ran_list = []
+ for i in range(num_floats):
+ random_floats = [random.random()]
+ ran_list.append(random_floats)
+ return sorted(ran_list, reverse=True)
+
+
+@pytest.mark.parametrize(
+ "milestones, lr_list",
+ [
+ ([10, 430], [[0.05], [0.005]]),
+ (list(sorted(random.sample(range(0, 999), 2))), generate_random_lr_list(2))
+ ]
+)
+def test_scheduler(milestones, lr_list):
+ """Testing schedulers.py.
+
+ Given a list of milestones and a list of learning rates, this function tests
+ if the length of each list is the same. Furthermore, it will assess if
+ the current learning rate (output from the function we are testing) is a float
+ and corresponds to the expected learning rate given the milestones.
+
+ Args:
+ milestones: list of epochs indices (number of epochs)
+ lr_list: learning rates list
+
+ Returns:
+ None
+
+ Examples:
+ input:
+ milestones = [10,25,50]
+ lr_list = [[0.00001],[0.000005],[0.000001]]
+ """
+
+ assert len(milestones) == len(lr_list)
+
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=0.01)
+ s = schedulers.LRListScheduler(optimizer, milestones=milestones, lr_list=lr_list)
+
+ index_rng = range(milestones[0], milestones[1])
+ for i in range((milestones[-1]) + 1):
+ if i < milestones[0]:
+ expected_lr = [0.01]
+ elif i in index_rng:
+ expected_lr = lr_list[0]
+ else:
+ expected_lr = lr_list[1]
+
+ current_lr = s.get_lr()[0]
+ assert s.get_lr() == expected_lr
+ assert isinstance(current_lr, float)
+ optimizer.step()
+ s.step()
+
+
+@dataclass
+class SchedulerTestConfig:
+ cfg: dict
+ init_lr: float
+ expected_lrs: list[float]
+
+
+TEST_SCHEDULERS = [
+ SchedulerTestConfig(
+ cfg=dict(
+ type="LRListScheduler",
+ params=dict(milestones=[2, 5], lr_list=[[0.5], [0.1]])
+ ),
+ init_lr=1.0,
+ expected_lrs=[1.0, 1.0, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="LRListScheduler", params=dict(milestones=[1], lr_list=[[0.1]])),
+ init_lr=0.1,
+ expected_lrs=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="LRListScheduler", params=dict(milestones=[1], lr_list=[[0.5]])),
+ init_lr=0.1,
+ expected_lrs=[0.1, 0.5, 0.5, 0.5],
+ ),
+ SchedulerTestConfig(
+ cfg=dict(type="StepLR", params=dict(step_size=3, gamma=0.1)),
+ init_lr=1.0,
+ expected_lrs=[1.0, 1.0, 1.0, 0.1, 0.1, 0.1, 0.01, 0.01, 0.01, 0.001],
+ ),
+]
+
+
+@pytest.mark.parametrize("test_cfg", TEST_SCHEDULERS)
+def test_build_scheduler(test_cfg: SchedulerTestConfig) -> None:
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=test_cfg.init_lr)
+ s = schedulers.build_scheduler(test_cfg.cfg, optimizer)
+ print()
+ print(f"Scheduler: {s}")
+ num_epochs = len(test_cfg.expected_lrs)
+ for e in range(num_epochs):
+ _assert_learning_rates_match(e, optimizer, test_cfg.expected_lrs[e])
+ optimizer.step()
+ s.step()
+
+
+@pytest.mark.parametrize("test_cfg", TEST_SCHEDULERS)
+def test_resume_scheduler_after_each_epoch(test_cfg: SchedulerTestConfig) -> None:
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=test_cfg.init_lr)
+ s = schedulers.build_scheduler(test_cfg.cfg, optimizer)
+ print()
+ print(f"Scheduler: {s}")
+ num_epochs = len(test_cfg.expected_lrs)
+ for e in range(num_epochs):
+ _assert_learning_rates_match(e, optimizer, test_cfg.expected_lrs[e])
+ optimizer.step()
+ s.step()
+
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=test_cfg.init_lr)
+ new_scheduler = schedulers.build_scheduler(test_cfg.cfg, optimizer)
+ schedulers.load_scheduler_state(new_scheduler, s.state_dict())
+ s = new_scheduler
+
+
+@pytest.mark.parametrize(
+ "test_cfg, middle_epoch",
+ [
+ (TEST_SCHEDULERS[0], 3),
+ (TEST_SCHEDULERS[1], 5),
+ (TEST_SCHEDULERS[2], 2),
+ (TEST_SCHEDULERS[3], 2),
+ (TEST_SCHEDULERS[3], 3),
+ (TEST_SCHEDULERS[3], 4),
+ ],
+)
+def test_two_stage_training(test_cfg: SchedulerTestConfig, middle_epoch: int) -> None:
+ num_epochs = len(test_cfg.expected_lrs)
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=test_cfg.init_lr)
+ s = schedulers.build_scheduler(test_cfg.cfg, optimizer)
+
+ print()
+ print(f"Scheduler: {s}")
+ for e in range(middle_epoch):
+ _assert_learning_rates_match(e, optimizer, test_cfg.expected_lrs[e])
+ optimizer.step()
+ s.step()
+
+ optimizer = torch.optim.SGD([torch.randn(2, 2)], lr=test_cfg.init_lr)
+ new_scheduler = schedulers.build_scheduler(test_cfg.cfg, optimizer)
+ schedulers.load_scheduler_state(new_scheduler, s.state_dict())
+ s = new_scheduler
+ for e in range(middle_epoch, num_epochs):
+ _assert_learning_rates_match(e, optimizer, test_cfg.expected_lrs[e])
+ s.step()
+
+
+@pytest.mark.parametrize(
+ "data",
+ [
+ dict( # example with 3 warm-up epochs
+ config=dict(
+ dict(
+ type="ConstantLR",
+ params=dict(factor=0.1, total_iters=3),
+ ),
+ ),
+ start_lr=1.0,
+ expected_lrs=[[0.1], [0.1], [0.1], [1.0], [1.0]],
+ ),
+ dict( # example from torch.optim.lr_scheduler.SequentialLR
+ config=dict(
+ type="SequentialLR",
+ params=dict(
+ schedulers=[
+ dict(
+ type="ConstantLR",
+ params=dict(factor=0.1, total_iters=2),
+ ),
+ dict(type="ExponentialLR", params=dict(gamma=0.9)),
+ ],
+ milestones=[2],
+ ),
+ ),
+ start_lr=1.0,
+ expected_lrs=[[0.1], [0.1], [1.0], [0.9], [0.81], [0.729]],
+ ),
+ dict( # example from torch.optim.lr_scheduler.SequentialLR
+ config=dict(
+ type="SequentialLR",
+ params=dict(
+ schedulers=[
+ dict(
+ type="ConstantLR",
+ params=dict(factor=0.1, total_iters=2),
+ ),
+ dict(type="StepLR", params=dict(step_size=2, gamma=0.1)),
+ ],
+ milestones=[5],
+ ),
+ ),
+ start_lr=1.0,
+ expected_lrs=[
+ [0.1], [0.1], [1.0], [1.0], [1.0], # ConstantLR
+ [1.0], [1.0], [0.1], [0.1], [0.01], # StepLR
+ ],
+ ),
+ ],
+)
+def test_build_sequential_lr(data):
+ print("\nTESTING")
+ start_lr = data["start_lr"]
+ print(f"Start LR: {start_lr}")
+ model = nn.Linear(in_features=1, out_features=1)
+ optimizer = torch.optim.SGD(params=model.parameters(), lr=start_lr)
+
+ print("BUILDING")
+ scheduler = schedulers.build_scheduler(data["config"], optimizer)
+
+ print("RUNNING")
+ lrs = []
+ for epoch in range(len(data["expected_lrs"])):
+ lrs.append(scheduler.get_last_lr())
+ print(scheduler.get_last_lr())
+ scheduler.step()
+
+ print(f"Expected: {data['expected_lrs']}")
+ print(f"Actual: {lrs}")
+ np.testing.assert_allclose(
+ np.asarray(data["expected_lrs"]),
+ np.asarray(lrs),
+ atol=1e-10,
+ )
+
+
+def _assert_learning_rates_match(e, optimizer, expected):
+ current_lrs = [g["lr"] for g in optimizer.param_groups]
+ print(f"Epoch {e}: LR={current_lrs}, expected={expected}")
+ for lr in current_lrs:
+ assert isinstance(lr, float)
+ np.testing.assert_almost_equal(lr, expected)
diff --git a/tests/pose_estimation_pytorch/runners/test_task.py b/tests/pose_estimation_pytorch/runners/test_task.py
new file mode 100644
index 0000000000..7a0a9730c2
--- /dev/null
+++ b/tests/pose_estimation_pytorch/runners/test_task.py
@@ -0,0 +1,27 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+""" Tests the Task enum """
+import pytest
+
+from deeplabcut.pose_estimation_pytorch.task import Task
+
+
+@pytest.mark.parametrize(
+ "task, task_strings",
+ [
+ (Task.BOTTOM_UP, ["bu", "BU", "bU", "Bu"]),
+ (Task.TOP_DOWN, ["TD", "tD"]),
+ (Task.DETECT, ["dt", "DT"]),
+ ],
+)
+def test_build_task(task: Task, task_strings: list[str]):
+ for s in task_strings:
+ assert task == Task(s)
diff --git a/tests/test_crossvalutils.py b/tests/test_crossvalutils.py
index b3f1dbb4aa..3dcf1dcbe1 100644
--- a/tests/test_crossvalutils.py
+++ b/tests/test_crossvalutils.py
@@ -10,8 +10,7 @@
#
import numpy as np
import pickle
-from deeplabcut.pose_estimation_tensorflow.lib import crossvalutils
-
+from deeplabcut.core import crossvalutils
BEST_GRAPH = [14, 15, 16, 11, 22, 31, 61, 7, 59, 62, 64]
BEST_GRAPH_MONTBLANC = [1, 0, 2, 5, 4, 3]
diff --git a/tests/test_inferenceutils.py b/tests/test_inferenceutils.py
index 22877b1bfd..5b4c126f58 100644
--- a/tests/test_inferenceutils.py
+++ b/tests/test_inferenceutils.py
@@ -14,7 +14,7 @@
import pytest
from conftest import TEST_DATA_DIR
from copy import deepcopy
-from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils
+from deeplabcut.core import inferenceutils
from scipy.spatial.distance import squareform
@@ -61,17 +61,17 @@ def test_calc_object_keypoint_similarity(real_assemblies):
def test_match_assemblies(real_assemblies):
assemblies = real_assemblies[0]
- matched, unmatched = inferenceutils.match_assemblies(
+ num_gt, matches = inferenceutils.match_assemblies(
assemblies, assemblies[::-1], 0.01
)
- assert not unmatched
- for ass1, ass2, oks in matched:
- assert ass1 is ass2
- assert oks == 1
-
- matched, unmatched = inferenceutils.match_assemblies([], assemblies, 0.01)
- assert not matched
- assert all(ass1 is ass2 for ass1, ass2 in zip(unmatched, assemblies))
+ assert len(assemblies) == len(matches)
+ for m in matches:
+ assert m.prediction is m.ground_truth
+ assert m.oks == 1
+
+ num_gt, matches = inferenceutils.match_assemblies([], assemblies, 0.01)
+ assert len(matches) == 0
+ assert num_gt == len(assemblies)
def test_evaluate_assemblies(real_assemblies):
diff --git a/tests/test_predict_supermodel.py b/tests/test_predict_supermodel.py
index 767e2739a8..e10d211400 100644
--- a/tests/test_predict_supermodel.py
+++ b/tests/test_predict_supermodel.py
@@ -10,7 +10,7 @@
#
import numpy as np
import pytest
-from deeplabcut.modelzoo.api import superanimal_inference
+from deeplabcut.pose_estimation_tensorflow.modelzoo.api import superanimal_inference
def test_get_multi_scale_frames():
diff --git a/tests/test_trackingutils.py b/tests/test_trackingutils.py
index 1795db03ee..984fcc2a76 100644
--- a/tests/test_trackingutils.py
+++ b/tests/test_trackingutils.py
@@ -10,7 +10,7 @@
#
import numpy as np
import pytest
-from deeplabcut.pose_estimation_tensorflow.lib import trackingutils
+from deeplabcut.core import trackingutils
@pytest.fixture()
diff --git a/tests/utils/test_multiprocessing.py b/tests/utils/test_multiprocessing.py
new file mode 100644
index 0000000000..34333be81d
--- /dev/null
+++ b/tests/utils/test_multiprocessing.py
@@ -0,0 +1,37 @@
+#
+# DeepLabCut Toolbox (deeplabcut.org)
+# © A. & M.W. Mathis Labs
+# https://github.com/DeepLabCut/DeepLabCut
+#
+# Please see AUTHORS for contributors.
+# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS
+#
+# Licensed under GNU Lesser General Public License v3.0
+#
+import pytest
+import time
+from deeplabcut.utils.multiprocessing import call_with_timeout
+
+
+def _succeeding_method(parameter):
+ return parameter
+
+
+def _failing_method():
+ raise ValueError("Raise value error on purpose")
+
+
+def _hanging_method():
+ while True:
+ time.sleep(5)
+
+
+def test_call_with_timeout():
+ parameter = (10, "Hello test")
+ assert call_with_timeout(_succeeding_method, 30, parameter) == parameter
+
+ with pytest.raises(ValueError):
+ call_with_timeout(_failing_method, timeout=30)
+
+ with pytest.raises(TimeoutError):
+ call_with_timeout(_hanging_method, timeout=1)
diff --git a/testscript_cli.py b/testscript_cli.py
index 7b74543cdc..d808678497 100644
--- a/testscript_cli.py
+++ b/testscript_cli.py
@@ -20,6 +20,7 @@
# install("tensorflow==1.13.1")
import deeplabcut as dlc
+from deeplabcut.core.engine import Engine
from pathlib import Path
import pandas as pd
@@ -28,6 +29,8 @@
print("Imported DLC!")
+engine = Engine.TF
+
basepath = os.path.dirname(os.path.abspath("testscript_cli.py"))
videoname = "reachingvideo1"
video = [
@@ -116,7 +119,7 @@
print("CREATING TRAININGSET")
dlc.create_training_dataset(
- path_config_file, net_type=net_type, augmenter_type=augmenter_type
+ path_config_file, net_type=net_type, augmenter_type=augmenter_type, engine=engine,
)
posefile = os.path.join(
diff --git a/tools/README.md b/tools/README.md
index 32389da346..07632283a3 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -1,5 +1,11 @@
# Developer tools useful for maintaining the repository
+As developer you'll need:
+
+```bash
+pip install coverage pytest fnmatch black
+```
+
## Code headers
The code headers can be standardized by running
@@ -11,3 +17,25 @@ python tools/update_license_headers.py
from the repository root.
You can edit the `NOTICE.yml` to update the header.
+
+
+## Workflow for contributing/checking your code
+
+```bash
+black .
+```
+
+## Running the tests (locally)
+
+We use the pytest framework. You can just run:
+
+```bash
+pytest
+```
+
+For coverage run:
+
+```
+coverage run -m pytest
+coverage report
+```