diff --git a/.circleci/config.yml b/.circleci/config.yml index 00da359495..a8cb967fab 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,7 +7,7 @@ jobs: build-and-test: working_directory: ~/circleci-demo-python-django docker: - - image: circleci/python:3.8 # primary container for the build job + - image: circleci/python:3.10 # primary container for the build job auth: username: mydockerhub-user password: $DOCKERHUB_PASSWORD # context / project UI env-var reference diff --git a/.github/workflows/publish-book.yml b/.github/workflows/publish-book.yml index 426e87370e..6aff2725d2 100644 --- a/.github/workflows/publish-book.yml +++ b/.github/workflows/publish-book.yml @@ -11,15 +11,15 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.10 - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install .[tf,docs] + python -m pip install .[docs] pip install jupyter-book sphinxcontrib-mermaid - name: Build the book diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 04a4580456..c432acde6f 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -64,3 +64,5 @@ jobs: pip install git+https://github.com/${{ github.repository }}.git@${{ github.sha }} python examples/testscript.py python examples/testscript_multianimal.py + python examples/testscript_pytorch_single_animal.py + python examples/testscript_pytorch_multi_animal.py diff --git a/.gitignore b/.gitignore index 86a6943487..ad9f192703 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ _build/* /examples/3D* /examples/m3* /examples/OUT +/examples/pretrained* .local .DS_Store examples/.DS_Store @@ -18,6 +19,15 @@ examples/.DS_Store *.ckpt snapshot-* +# Modelzoo checkpoints +deeplabcut/modelzoo/checkpoints/ + +# PyTorch backbone weights +deeplabcut/pose_estimation_pytorch/models/backbones/pretrained_weights/ + +# Wandb files +wandb/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -115,7 +125,10 @@ ENV/ # Spyder project settings .spyderproject .spyproject + +# IDEs configurations .vscode/* +.idea/* # Rope project settings .ropeproject diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..a21cfeebbd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-added-large-files + - id: check-yaml + - id: end-of-file-fixer + - id: name-tests-test + - id: trailing-whitespace + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3 diff --git a/NOTICE.yml b/NOTICE.yml index 68ff6362df..ec8df871b2 100644 --- a/NOTICE.yml +++ b/NOTICE.yml @@ -5,7 +5,7 @@ https://github.com/DeepLabCut/DeepLabCut Please see AUTHORS for contributors. - https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS + https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS Licensed under GNU Lesser General Public License v3.0 include: @@ -17,6 +17,7 @@ # License for files adapted from DeeperCut by Eldar Insafutdinov # https://github.com/eldar/pose-tensorflow + # Applies to most files in deeplabcut.pose_estimation_tensorflow - header: | DeepLabCut Toolbox (deeplabcut.org) @@ -107,3 +108,8 @@ include: - deeplabcut/pose_tracking_pytorch/solver/scheduler_factory.py - deeplabcut/pose_tracking_pytorch/model/backones/vit_pytorch.py + +# PyTorch license + +- header: | + See https://github.com/pytorch/pytorch/blob/main/LICENSE diff --git a/README.md b/README.md index eea3ad6206..3d8319965a 100644 --- a/README.md +++ b/README.md @@ -60,15 +60,24 @@ Please click the link above for all the information you need to get started! Please note that currently we support only Python 3.10+ (see conda files for guidance). -Developers Stable Release: -- Very quick start: You need to have TensorFlow installed (up to v2.10 supported across platforms) `pip install "deeplabcut[gui,tf]"` that includes all functions plus GUIs, or `pip install deeplabcut[tf]` (headless version with PyTorch and TensorFlow). - -Developers Alpha Release: -- We also have an alpha release of PyTorch DeepLabCut available! [Please see here for instructions and information](https://github.com/DeepLabCut/DeepLabCut/blob/pytorch_docs/docs/pytorch/user_guide.md). +Developers Stable Release: very quick start (Python 3.10+ required) to install +DeepLabCut with the PyTorch engine + +- [Install PyTorch](https://pytorch.org/get-started/locally/) (**select the desired +CUDA version if you want to use a GPU**): `pip install torch torchvision` +- Then, [install `pytables`](https://www.pytables.org/usersguide/installation.html): `conda install -c conda-forge pytables==3.8.0` +- Finally, install `DeepLabCut` (with all functions + the GUI): +`pip install --pre "deeplabcut[gui]"` or `pip install --pre "deeplabcut"` (headless +version with PyTorch)! + +To use the TensorFlow engine (requires Python 3.10; TF up to v2.10 supported on Windows, +up to v2.12 on other platforms): you'll need to run `pip install "deeplabcut[gui,tf]"` +(which includes all functions plus GUIs) or `pip install "deeplabcut[tf]"` (headless +version with PyTorch and TensorFlow). We recommend using our conda file, see [here](https://github.com/DeepLabCut/DeepLabCut/blob/main/conda-environments/README.md) or the new [`deeplabcut-docker` package](https://github.com/DeepLabCut/DeepLabCut/tree/main/docker). -# [Documentation: The DeepLabCut Process](https://deeplabcut.github.io/DeepLabCut) +# [Documentation: The DeepLabCut Process](https://deeplabcut.github.io/DeepLabCut/README.html) Our docs walk you through using DeepLabCut, and key API points. For an overview of the toolbox and workflow for project management, see our step-by-step at [Nature Protocols paper](https://doi.org/10.1038/s41596-019-0176-0). @@ -82,9 +91,7 @@ For a deeper understanding and more resources for you to get started with Python 🐭 pose tracking of single animals demo [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_DEMO_mouse_openfield.ipynb) -🐭🐭🐭 pose tracking of multiple animals demo [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DeepLabCut/DeepLabCut/blob/master/examples/COLAB/COLAB_3miceDemo.ipynb) - -- See [more demos here](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/README.md). We provide data and several Jupyter Notebooks: one that walks you through a demo dataset to test your installation, and another Notebook to run DeepLabCut from the beginning on your own data. We also show you how to use the code in Docker, and on Google Colab. +See [more demos here](https://github.com/DeepLabCut/DeepLabCut/blob/main/examples/README.md). We provide data and several Jupyter Notebooks: one that walks you through a demo dataset to test your installation, and another Notebook to run DeepLabCut from the beginning on your own data. We also show you how to use the code in Docker, and on Google Colab. # Why use DeepLabCut? diff --git a/_config.yml b/_config.yml index c55ef8cebb..a24fcaf718 100644 --- a/_config.yml +++ b/_config.yml @@ -5,7 +5,7 @@ only_build_toc_files: true sphinx: config: - autodoc_mock_imports: ["wx"] + autodoc_mock_imports: ["wx", "matplotlib", "qtpy", "PySide6", "napari", "shiboken6"] mermaid_output_format: raw extra_extensions: - numpydoc diff --git a/_toc.yml b/_toc.yml index 471655454f..8935c8ff8e 100644 --- a/_toc.yml +++ b/_toc.yml @@ -46,10 +46,8 @@ parts: chapters: - file: docs/ModelZoo - file: docs/recipes/UsingModelZooPupil - - file: docs/recipes/MegaDetectorDLCLive - caption: 🧑‍🍳 Cookbook (detailed helper guides) chapters: - - file: docs/tutorial - file: docs/convert_maDLC - file: docs/recipes/OtherData - file: docs/recipes/io diff --git a/conda-environments/DEEPLABCUT.yaml b/conda-environments/DEEPLABCUT.yaml index 66edc47212..5b5c77d4c6 100644 --- a/conda-environments/DEEPLABCUT.yaml +++ b/conda-environments/DEEPLABCUT.yaml @@ -9,9 +9,13 @@ #Licensed under GNU Lesser General Public License v3.0 # # DeepLabCut environment -# FIRST: INSTALL CORRECT DRIVER for GPU, see https://stackoverflow.com/questions/30820513/what-is-the-correct-version-of-cuda-for-my-nvidia-driver/30820690 # -# AFTER THIS FILE IS INSTALLED, if you have a GPU be sure to install cudnn from conda-forge: conda install cudnn -c conda-forge +# FIRST: If you have an NVIDIA GPU and want to use it, check that you have drivers installed! +# To check if your GPUs are visible to PyTorch (and thus DeepLabCut), run: +# >>> python -c "import torch; print(torch.cuda.is_available())" +# +# If "False" is printed, PyTorch (and thus DeepLabCut) cannot access your GPU. For +# more information, see: https://pytorch.org/get-started/locally/ # # install: conda env create -f DEEPLABCUT.yaml # update: conda env update -f DEEPLABCUT.yaml @@ -29,4 +33,6 @@ dependencies: - ffmpeg - pytables==3.8.0 - pip: + - torch + - torchvision - "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut[gui,modelzoo,wandb]" diff --git a/conda-environments/DEEPLABCUT_M1.yaml b/conda-environments/DEEPLABCUT_M1.yaml deleted file mode 100644 index 47e8c02572..0000000000 --- a/conda-environments/DEEPLABCUT_M1.yaml +++ /dev/null @@ -1,43 +0,0 @@ -# DEEPLABCUT_M1.yaml - -#DeepLabCut2.0 Toolbox (deeplabcut.org) -#© A. & M.W. Mathis Labs -#https://github.com/DeepLabCut/DeepLabCut -#Please see AUTHORS for contributors. - -#https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS -#Licensed under GNU Lesser General Public License v3.0 -# -# DeepLabCut M1/M2 (Apple Silicon) environment instructions -# -# We'll get the miniconda M1 bash installer, as explained in -# https://docs.conda.io/projects/conda/en/latest/user-guide/install/macos.html -# -# In the Terminal, run the following commands: -# wget https://repo.anaconda.com/miniconda/Miniconda3-py39_4.12.0-MacOSX-arm64.sh -O ~/miniconda.sh -# bash ~/miniconda.sh -b -p $HOME/miniconda -# source ~/miniconda/bin/activate -# conda init zsh -# -# Then, `git clone DeepLabCut`, and run: -# -# conda env create -f conda-environments/DEEPLABCUT_M1.yaml -# -# Next, activate the environment, and launch DLC with pythonw -m deeplabcut - -name: DEEPLABCUT_M1 -channels: - - conda-forge - - defaults -dependencies: - - python=3.10 - - pip - - ipython - - jupyter - - nb_conda - - notebook<7.0.0 - - python.app - - ffmpeg - - apple::tensorflow-deps - - pip: - - "deeplabcut[gui,apple_mchips]" diff --git a/deeplabcut/__init__.py b/deeplabcut/__init__.py index 2da4b6a9f5..72dac1e3ce 100644 --- a/deeplabcut/__init__.py +++ b/deeplabcut/__init__.py @@ -12,10 +12,6 @@ import os -# Suppress tensorflow warning messages -import tensorflow as tf - -tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) DEBUG = True and "DEBUG" in os.environ and os.environ["DEBUG"] from deeplabcut.version import __version__, VERSION @@ -34,6 +30,7 @@ "DLC loaded in light mode; you cannot use any GUI (labeling, relabeling and standalone GUI)" ) +from deeplabcut.core.engine import Engine from deeplabcut.create_project import ( create_new_project, create_new_project_3d, @@ -49,6 +46,7 @@ mergeandsplit, ) from deeplabcut.generate_training_dataset import ( + create_training_dataset_from_existing_split, create_training_model_comparison, create_multianimaltraining_dataset, ) @@ -60,6 +58,9 @@ dropduplicatesinannotatinfiles, dropunlabeledframes, ) + +from deeplabcut.modelzoo.video_inference import video_inference_superanimal + from deeplabcut.utils import ( create_labeled_video, create_video_with_all_detections, @@ -92,13 +93,14 @@ ) # Train, evaluate & predict functions / all require TF -from deeplabcut.pose_estimation_tensorflow import ( +from deeplabcut.compat import ( train_network, return_train_network_path, evaluate_network, return_evaluate_network_data, analyze_videos, create_tracking_dataset, + analyze_images, analyze_time_lapse_frames, convert_detections2tracklets, extract_maps, @@ -107,7 +109,6 @@ visualize_paf, extract_save_all_maps, export_model, - video_inference_superanimal, ) diff --git a/deeplabcut/__main__.py b/deeplabcut/__main__.py index 93b3f44b64..8d8c782a74 100644 --- a/deeplabcut/__main__.py +++ b/deeplabcut/__main__.py @@ -9,20 +9,25 @@ # Licensed under GNU Lesser General Public License v3.0 # -try: - import PySide6 +def main(): + try: + import PySide6 - lite = False -except ModuleNotFoundError: - lite = True + lite = False + except ModuleNotFoundError: + lite = True -# if module is executed directly (i.e. `python -m deeplabcut.__init__`) launch straight into the GUI -if not lite: - print("Starting GUI...") - from deeplabcut.gui.launch_script import launch_dlc + # if module is executed directly (i.e. `python -m deeplabcut.__init__`) launch straight into the GUI + if not lite: + print("Starting GUI...") + from deeplabcut.gui.launch_script import launch_dlc - launch_dlc() -else: - print( - "You installed DLC lite, thus GUI's cannot be used. If you need GUI support please: pip install 'deeplabcut[gui]''" - ) + launch_dlc() + else: + print( + "You installed DLC lite, thus GUI's cannot be used. If you need GUI support please: pip install 'deeplabcut[gui]''" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/deeplabcut/benchmark/__init__.py b/deeplabcut/benchmark/__init__.py index 2e70eae786..a5ddd3c4a9 100644 --- a/deeplabcut/benchmark/__init__.py +++ b/deeplabcut/benchmark/__init__.py @@ -85,7 +85,11 @@ def evaluate( continue benchmark = benchmark_cls() for name in benchmark.names(): - if Result(method_name=name, benchmark_name=benchmark_cls.name) in results: + if Result( + code=benchmark.code, + method_name=name, + benchmark_name=benchmark_cls.name, + ) in results: continue else: result = benchmark.evaluate(name, on_error=on_error) diff --git a/deeplabcut/benchmark/base.py b/deeplabcut/benchmark/base.py index 49140c4436..ef6894ac3b 100644 --- a/deeplabcut/benchmark/base.py +++ b/deeplabcut/benchmark/base.py @@ -59,7 +59,7 @@ def get_predictions(self): raise NotImplementedError() def __init__(self): - keys = ["name", "keypoints", "ground_truth", "metadata"] + keys = ["code", "name", "keypoints", "ground_truth", "metadata"] for key in keys: if not hasattr(self, key): raise NotImplementedError( @@ -110,6 +110,7 @@ def evaluate(self, name: str, on_error="raise"): else: raise NotImplementedError() from exception return Result( + code=self.code, method_name=name, benchmark_name=self.name, mean_avg_precision=mean_avg_precision, @@ -139,6 +140,7 @@ def _validate_predictions(self, name: str, predictions: dict) -> dict: class Result: """Benchmark result.""" + code: str method_name: str benchmark_name: str root_mean_squared_error: float = float("nan") @@ -146,6 +148,7 @@ class Result: benchmark_version: str = __version__ _export_mapping = dict( + code="code", benchmark_name="benchmark", method_name="method", benchmark_version="version", diff --git a/deeplabcut/benchmark/benchmarks.py b/deeplabcut/benchmark/benchmarks.py index ee18e215c6..4068c29cf2 100644 --- a/deeplabcut/benchmark/benchmarks.py +++ b/deeplabcut/benchmark/benchmarks.py @@ -96,6 +96,16 @@ def compute_pose_map(self, results_objects): symmetric_kpts=[(0, 4), (1, 3)], ) + def _validate_predictions(self, name: str, predictions: dict) -> dict: + """Fixes filenames for predictions made on old versions of the dataset""" + return super()._validate_predictions( + name, + { + k.replace("Dummy", "D").replace("Dead pup", "DP"): v + for k, v in predictions.items() + }, + ) + class MarmosetBenchmark(deeplabcut.benchmark.base.Benchmark): """Dataset with two marmosets. diff --git a/deeplabcut/benchmark/metrics.py b/deeplabcut/benchmark/metrics.py index ba735a9db7..e73eb4cba2 100644 --- a/deeplabcut/benchmark/metrics.py +++ b/deeplabcut/benchmark/metrics.py @@ -29,8 +29,7 @@ import pandas as pd import deeplabcut.benchmark.utils -from deeplabcut.pose_estimation_tensorflow.core import evaluate_multianimal -from deeplabcut.pose_estimation_tensorflow.lib import inferenceutils +from deeplabcut.core import inferenceutils, crossvalutils from deeplabcut.utils.conversioncode import guarantee_multiindex_rows @@ -99,7 +98,7 @@ def calc_prediction_errors(preds, gt): if visible.size and xy_pred_.size: # Pick the predictions closest to ground truth, # rather than the ones the model has most confident in. - neighbors = evaluate_multianimal._find_closest_neighbors( + neighbors = crossvalutils.find_closest_neighbors( xy_gt_[visible], xy_pred_, k=3 ) found = neighbors != -1 @@ -213,6 +212,7 @@ def calc_map_from_obj( oks_sigma, margin=margin, symmetric_kpts=symmetric_kpts, + greedy_matching=True, ) return oks["mAP"] diff --git a/deeplabcut/benchmark/mot.py b/deeplabcut/benchmark/mot.py new file mode 100644 index 0000000000..32c99b6987 --- /dev/null +++ b/deeplabcut/benchmark/mot.py @@ -0,0 +1,150 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + +from __future__ import annotations + +import warnings + +import motmetrics as mm +import numpy as np +import pandas as pd +from numpy.typing import NDArray + +from deeplabcut.core import trackingutils + + +def _convert_bboxes_to_xywh(bboxes: NDArray, inplace: bool = False) -> NDArray: + w = bboxes[:, 2] - bboxes[:, 0] + h = bboxes[:, 3] - bboxes[:, 1] + if not inplace: + new_bboxes = bboxes.copy() + new_bboxes[:, 2] = w + new_bboxes[:, 3] = h + return new_bboxes + bboxes[:, 2] = w + bboxes[:, 3] = h + + +def reconstruct_bboxes_from_bodyparts( + data: pd.DataFrame, margin: float, to_xywh: bool = False +) -> NDArray: + x = data.xs("x", axis=1, level="coords") + y = data.xs("y", axis=1, level="coords") + p = data.xs("likelihood", axis=1, level="coords") + xy = np.stack([x, y], axis=2) + bboxes = np.full((data.shape[0], 5), np.nan) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + bboxes[:, :2] = np.nanmin(xy, axis=1) - margin + bboxes[:, 2:4] = np.nanmax(xy, axis=1) + margin + bboxes[:, 4] = np.nanmean(p, axis=1) + if to_xywh: + _convert_bboxes_to_xywh(bboxes, inplace=True) + return bboxes + + +def reconstruct_all_bboxes( + data: pd.DataFrame, margin: float, to_xywh: bool = False +) -> NDArray: + animals = data.columns.get_level_values("individuals").unique().tolist() + try: + animals.remove("single") + except ValueError: + pass + bboxes = np.full((len(animals), data.shape[0], 5), np.nan) + for n, animal in enumerate(animals): + bboxes[n] = reconstruct_bboxes_from_bodyparts( + data.xs(animal, axis=1, level="individuals"), margin, to_xywh + ) + return bboxes + + +def compute_mot_metrics( + h5_file_gt: str, + h5_file_pred: str, + tracker_type: str = "bbox", + **kwargs, +) -> mm.MOTAccumulator: + df_gt = pd.read_hdf(h5_file_gt) + df = pd.read_hdf(h5_file_pred) + if tracker_type == "bbox": + func = reconstruct_all_bboxes + elif tracker_type == "ellipse": + func = trackingutils.reconstruct_all_ellipses + else: + raise ValueError(f"Unrecognized tracker type {tracker_type}.") + + trackers_gt = func(df_gt, **kwargs) + trackers = func(df, **kwargs) + return _compute_mot_metrics( + trackers_gt, trackers, tracker_type, + ) + + +def _compute_mot_metrics( + trackers_ground_truth: NDArray, + trackers: NDArray, + tracker_type: str = "bbox", +) -> mm.MOTAccumulator: + if trackers_ground_truth.shape != trackers.shape: + raise ValueError( + "Dimensions mismatch. There must be as many `trackers_ground_truth` as there are `trackers`." + ) + + if tracker_type == "bbox": + sl = slice(0, 4) + cost_func = mm.distances.iou_matrix + elif tracker_type == "ellipse": + sl = slice(0, 5) + + def cost_func(ellipses_gt, ellipses_hyp): + cost_matrix = np.zeros((len(ellipses_gt), len(ellipses_hyp))) + gt_el = [trackingutils.Ellipse(*e[:5]) for e in ellipses_gt] + hyp_el = [trackingutils.Ellipse(*e[:5]) for e in ellipses_hyp] + for i, el in enumerate(gt_el): + for j, tracker in enumerate(hyp_el): + cost_matrix[i, j] = 1 - el.calc_similarity_with(tracker) + return cost_matrix + + else: + raise ValueError(f"Unrecognized tracker type {tracker_type}.") + + ids = np.arange(trackers_ground_truth.shape[0]) + acc = mm.MOTAccumulator(auto_id=True) + for i in range(trackers_ground_truth.shape[1]): + trackers_gt = trackers_ground_truth[:, i, sl] + trackers_hyp = trackers[:, i, sl] + empty_gt = np.isnan(trackers_gt).any(axis=1) + empty_hyp = np.isnan(trackers_hyp).any(axis=1) + trackers_gt = trackers_gt[~empty_gt] + trackers_hyp = trackers_hyp[~empty_hyp] + cost = cost_func(trackers_gt, trackers_hyp) + acc.update(ids[~empty_gt], ids[~empty_hyp], cost) + return acc + + +def print_all_metrics( + accumulators: list[mm.MOTAccumulator], all_params: list[str] | None = None +): + if not all_params: + names = [f"iter{i + 1}" for i in range(len(accumulators))] + else: + s = "_".join("{}" for _ in range(len(all_params[0]))) + names = [s.format(*params.values()) for params in all_params] + mh = mm.metrics.create() + summary = mh.compute_many( + accumulators, metrics=mm.metrics.motchallenge_metrics, names=names + ) + strsummary = mm.io.render_summary( + summary, formatters=mh.formatters, namemap=mm.io.motchallenge_metric_names + ) + print(strsummary) + return summary diff --git a/deeplabcut/compat.py b/deeplabcut/compat.py new file mode 100644 index 0000000000..b5973cfbcd --- /dev/null +++ b/deeplabcut/compat.py @@ -0,0 +1,1968 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Compatibility file for methods available with either PyTorch or Tensorflow""" +from __future__ import annotations + +from pathlib import Path +from typing import Iterable + +import numpy as np +from ruamel.yaml import YAML + +import deeplabcut.core.visualization as visualization +from deeplabcut.core.engine import Engine +from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine + +DEFAULT_ENGINE = Engine.PYTORCH + + +def get_project_engine(cfg: dict) -> Engine: + """ + Args: + cfg: the project configuration file + + Returns: + the engine specified for the project, or the default engine if none is specified + """ + if cfg.get("engine") is not None: + return Engine(cfg["engine"]) + + return DEFAULT_ENGINE + + +def get_available_aug_methods(engine: Engine) -> tuple[str, ...]: + """ + Args: + engine: the engine for which augmentation methods should be returned + + Returns: + the augmentations available for the given engine, where the first one is the + default method to use + + Raises: + RuntimeError: if no augmentations methods are defined for the given engine + """ + if engine == Engine.TF: + return "imgaug", "default", "deterministic", "scalecrop", "tensorpack" + elif engine == Engine.PYTORCH: + return ("albumentations",) + + raise RuntimeError(f"Unknown augmentation for engine: {engine}") + + +def train_network( + config: str | Path, + shuffle: int = 1, + trainingsetindex: int = 0, + max_snapshots_to_keep: int | None = None, + displayiters: int | None = None, + saveiters: int | None = None, + maxiters: int | None = None, + epochs: int | None = None, + save_epochs: int | None = None, + allow_growth: bool = True, + gputouse: str | None = None, + autotune: bool = False, + keepdeconvweights: bool = True, + modelprefix: str = "", + superanimal_name: str = "", + superanimal_transfer_learning: bool = False, + engine: Engine | None = None, + device: str | None = None, + snapshot_path: str | Path | None = None, + detector_path: str | Path | None = None, + batch_size: int | None = None, + detector_batch_size: int | None = None, + detector_epochs: int | None = None, + detector_save_epochs: int | None = None, + pose_threshold: float | None = 0.1, + pytorch_cfg_updates: dict | None = None, +): + """ + Trains the network with the labels in the training dataset. + + Parameters + ---------- + config : string + Full path of the config.yaml file as a string. + + shuffle: int, optional, default=1 + Integer value specifying the shuffle index to select for training. + + trainingsetindex: int, optional, default=0 + Integer specifying which TrainingsetFraction to use. + Note that TrainingFraction is a list in config.yaml. + + max_snapshots_to_keep: int or None + Sets how many snapshots are kept, i.e. states of the trained network. Every + saving iteration many times a snapshot is stored, however only the last + ``max_snapshots_to_keep`` many are kept! If you change this to None, then all + are kept. + See: https://github.com/DeepLabCut/DeepLabCut/issues/8#issuecomment-387404835 + + displayiters: optional, default=None + This variable is actually set in ``pose_config.yaml``. However, you can + overwrite it with this hack. Don't use this regularly, just if you are too lazy + to dig out the ``pose_config.yaml`` file for the corresponding project. If + ``None``, the value from there is used, otherwise it is overwritten! + + saveiters: optional, default=None + Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``: + you can use ``save_epochs``). + This variable is actually set in ``pose_config.yaml``. However, you can + overwrite it with this hack. Don't use this regularly, just if you are too lazy + to dig out the ``pose_config.yaml`` file for the corresponding project. + If ``None``, the value from there is used, otherwise it is overwritten! + + maxiters: optional, default=None + Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``: + you can use ``epochs``). + This variable is actually set in ``pose_config.yaml``. However, you can + overwrite it with this hack. Don't use this regularly, just if you are too lazy + to dig out the ``pose_config.yaml`` file for the corresponding project. + If ``None``, the value from there is used, otherwise it is overwritten! + + epochs: optional, default=None + Only for the PyTorch engine (equivalent to the `maxiters` parameter for the + TensorFlow engine). The maximum number of epochs to train the model for. If + None, the value will be read from the `pytorch_config.yaml` file. An epoch is a + single pass through the training dataset, which means your model has seen each + training image exactly once. So if you have 64 training images for your network, + an epoch is 64 iterations with batch size 1 (or 32 iterations with batch size 2, + 16 with batch size 4, etc.). + + save_epochs: optional, default=None + Only for the PyTorch engine (equivalent to the `saveiters` parameter for the + TensorFlow engine). The number of epochs between each snapshot save. If + None, the value will be read from the `pytorch_config.yaml` file. + + allow_growth: bool, optional, default=True. + Only for the TensorFlow engine. + For some smaller GPUs the memory issues happen. If ``True``, the memory + allocator does not pre-allocate the entire specified GPU memory region, instead + starting small and growing as needed. + See issue: https://forum.image.sc/t/how-to-stop-running-out-of-vram/30551/2 + + gputouse: optional, default=None + Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``: + you can use ``device``). + Natural number indicating the number of your GPU (see number in nvidia-smi). + If you do not have a GPU put None. + See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + autotune: bool, optional, default=False + Only for the TensorFlow engine. + Property of TensorFlow, somehow faster if ``False`` + (as Eldar found out, see https://github.com/tensorflow/tensorflow/issues/13317). + + keepdeconvweights: bool, optional, default=True + Also restores the weights of the deconvolution layers (and the backbone) when + training from a snapshot. Note that if you change the number of bodyparts, you + need to set this to false for re-training. + + modelprefix: str, optional, default="" + Directory containing the deeplabcut models to use when evaluating the network. + By default, the models are assumed to exist in the project folder. + + superanimal_name: str, optional, default ="" + Only for the TensorFlow engine. For the PyTorch engine, you need to specify + this through the ``weight_init`` when creating the training dataset. + Specified if transfer learning with superanimal is desired + + superanimal_transfer_learning: bool, optional, default = False. + Only for the TensorFlow engine. For the PyTorch engine, you need to specify + this through the ``weight_init`` when creating the training dataset. + If set true, the training is transfer learning (new decoding layer). If set + false, and superanimal_name is True, then the training is fine-tuning (reusing + the decoding layer) + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + device: str, optional, default = None. + Only for the PyTorch engine. The device to run the training on (e.g. "cuda:0") + + snapshot_path: str or Path, optional, default = None. + Only for the PyTorch engine. The path to the pose model snapshot to resume training from. + + detector_path: str or Path, optional, default = None. + Only for the PyTorch engine. The path to the detector model snapshot to resume training from. + + batch_size: int, optional, default = None. + Only for the PyTorch engine. The batch size to use while training. + + detector_batch_size: int, optional, default = None. + Only for the PyTorch engine. The batch size to use while training the detector. + + detector_epochs: int, optional, default = None. + Only for the PyTorch engine. The number of epochs to train the detector for. + + detector_save_epochs: int, optional, default = None. + Only for the PyTorch engine. The number of epochs between each detector snapshot save. + + pose_threshold: float, optional, default = 0.1. + Only for the PyTorch engine. Used for memory-replay. Pseudo-predictions with confidence lower + than this threshold are discarded for memory-replay + + pytorch_cfg_updates: dict, optional, default = None. + A dictionary of updates to the pytorch config. The keys are the dot-separated + paths to the values to update in the config. + For example, to update the gpus to run the training on, you can use: + ``` + pytorch_cfg_updates={"runner.gpus": [0,1,2,3]} + ``` + + Returns + ------- + None + + Examples + -------- + To train the network for first shuffle of the training dataset + + >>> deeplabcut.train_network('/analysis/project/reaching-task/config.yaml') + + To train the network for second shuffle of the training dataset + + >>> deeplabcut.train_network( + '/analysis/project/reaching-task/config.yaml', + shuffle=2, + keepdeconvweights=True, + ) + + To train the network for shuffle created with a PyTorch engine, while overriding the + number of epochs, batch size and other parameters. + + >>> deeplabcut.train_network( + '/analysis/project/reaching-task/config.yaml', + shuffle=1, + batch_size=8, + epochs=100, + save_epochs=10, + display_iters=50, + ) + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import train_network + + if max_snapshots_to_keep is None: + max_snapshots_to_keep = 5 + + return train_network( + str(config), + shuffle=shuffle, + trainingsetindex=trainingsetindex, + max_snapshots_to_keep=max_snapshots_to_keep, + displayiters=displayiters, + saveiters=saveiters, + maxiters=maxiters, + allow_growth=allow_growth, + gputouse=gputouse, + autotune=autotune, + keepdeconvweights=keepdeconvweights, + superanimal_name=superanimal_name, + superanimal_transfer_learning=superanimal_transfer_learning, + modelprefix=modelprefix, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis import train_network + + return train_network( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + modelprefix=modelprefix, + device=device, + snapshot_path=snapshot_path, + detector_path=detector_path, + load_head_weights=keepdeconvweights, + batch_size=batch_size, + epochs=epochs, + save_epochs=save_epochs, + detector_batch_size=detector_batch_size, + detector_epochs=detector_epochs, + detector_save_epochs=detector_save_epochs, + display_iters=displayiters, + max_snapshots_to_keep=max_snapshots_to_keep, + pose_threshold=pose_threshold, + pytorch_cfg_updates=pytorch_cfg_updates, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def return_train_network_path( + config, + shuffle: int = 1, + trainingsetindex: int = 0, + modelprefix: str = "", + engine: Engine | None = None, +) -> tuple[Path, Path, Path]: + """ + Returns the training and test pose config file names as well as the folder where the + snapshot is + + Parameters + ---------- + config : string + Full path of the config.yaml file as a string. + + shuffle: int + Integer value specifying the shuffle index to select for training. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note + that TrainingFraction is a list in config.yaml). + + modelprefix: str, optional + Directory containing the deeplabcut models to use when evaluating the network. + By default, the models are assumed to exist in the project folder. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Returns the triple: trainposeconfigfile, testposeconfigfile, snapshotfolder + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import return_train_network_path + + return return_train_network_path( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + modelprefix=modelprefix, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis.utils import ( + return_train_network_path, + ) + + return return_train_network_path( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + modelprefix=modelprefix, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def evaluate_network( + config: str | Path, + Shuffles: Iterable[int] = (1,), + trainingsetindex: int | str = 0, + plotting: bool | str = False, + show_errors: bool = True, + comparisonbodyparts: str | list[str] = "all", + gputouse: str | None = None, + rescale: bool = False, + modelprefix: str = "", + per_keypoint_evaluation: bool = False, + snapshots_to_evaluate: list[str] | None = None, + pcutoff: float | list[float] | dict[str, float] | None = None, + engine: Engine | None = None, + **torch_kwargs, +): + """Evaluates the network. + + Evaluates the network based on the saved models at different stages of the training + network. The evaluation results are stored in the .h5 and .csv file under the + subdirectory 'evaluation_results'. Change the snapshotindex parameter in the config + file to 'all' in order to evaluate all the saved models. + + Parameters + ---------- + config : string + Full path of the config.yaml file. + + Shuffles: list, optional, default=[1] + List of integers specifying the shuffle indices of the training dataset. + + trainingsetindex: int or str, optional, default=0 + Integer specifying which "TrainingsetFraction" to use. + Note that "TrainingFraction" is a list in config.yaml. This variable can also + be set to "all". + + plotting: bool or str, optional, default=False + Plots the predictions on the train and test images. + If provided it must be either ``True``, ``False``, ``"bodypart"``, or + ``"individual"``. Setting to ``True`` defaults as ``"bodypart"`` for + multi-animal projects. + If a detector is used, the predicted bounding boxes will also be plotted. + + show_errors: bool, optional, default=True + Display train and test errors. + + comparisonbodyparts: str or list, optional, default="all" + The average error will be computed for those body parts only. + The provided list has to be a subset of the defined body parts. + + gputouse: int or None, optional, default=None + Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a + GPU put `None``. + See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + rescale: bool, optional, default=False + Evaluate the model at the ``'global_scale'`` variable (as set in the + ``pose_config.yaml`` file for a particular project). I.e. every image will be + resized according to that scale and prediction will be compared to the resized + ground truth. The error will be reported in pixels at rescaled to the + *original* size. I.e. For a [200,200] pixel image evaluated at + ``global_scale=.5``, the predictions are calculated on [100,100] pixel images, + compared to 1/2*ground truth and this error is then multiplied by 2!. + The evaluation images are also shown for the original size! + + modelprefix: str, optional, default="" + Directory containing the deeplabcut models to use when evaluating the network. + By default, the models are assumed to exist in the project folder. + + per_keypoint_evaluation: bool, default=False + Compute the train and test RMSE for each keypoint, and save the results to + a {model_name}-keypoint-results.csv in the evalution-results folder + + snapshots_to_evaluate: List[str], optional, default=None + List of snapshot names to evaluate (e.g. ["snapshot-5000", "snapshot-7500"]). + + pcutoff: float | list[float] | dict[str, float] | None, default=None + Only for the PyTorch engine. For the TensorFlow engine, please set the pcutoff + in the `config.yaml` file. + The cutoff to use for computing evaluation metrics. When `None` (default), the + cutoff will be loaded from the project config. If a list is provided, there + should be one value for each bodypart and one value for each unique bodypart + (if there are any). If a dict is provided, the keys should be bodyparts + mapping to pcutoff values for each bodypart. Bodyparts that are not defined + in the dict will have pcutoff set to 0.6. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + torch_kwargs: + You can add any keyword arguments for the deeplabcut.pose_estimation_pytorch + evaluate_network function here. These arguments are passed to the downstream + function. Available parameters are `snapshotindex`, which overrides the + `snapshotindex` parameter in the project configuration file. For top-down models + the `detector_snapshot_index` parameter can override the index of the detector + to use for evaluation in the project configuration file. + + Returns + ------- + None + + Examples + -------- + If you do not want to plot and evaluate with shuffle set to 1. + + >>> deeplabcut.evaluate_network( + '/analysis/project/reaching-task/config.yaml', Shuffles=[1], + ) + + If you want to plot and evaluate with shuffle set to 0 and 1. + + >>> deeplabcut.evaluate_network( + '/analysis/project/reaching-task/config.yaml', + Shuffles=[0, 1], + plotting=True, + ) + + If you want to plot assemblies for a maDLC project + + >>> deeplabcut.evaluate_network( + '/analysis/project/reaching-task/config.yaml', + Shuffles=[1], + plotting="individual", + ) + + If you have a PyTorch model for which you want to set a different p-cutoff for + "left_ear" and "right_ear" bodyparts, and keep the one set in the project config + for other bodyparts: + + >>> deeplabcut.evaluate_network( + >>> "/analysis/project/reaching-task/config.yaml", + >>> Shuffles=[0, 1], + >>> pcutoff={"left_ear": 0.8, "right_ear": 0.8}, + >>> ) + + Note: This defaults to standard plotting for single-animal projects. + """ + if engine is None: + cfg = _load_config(config) + engines = set() + for shuffle in Shuffles: + engines.add( + get_shuffle_engine( + cfg, + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + ) + if len(engines) == 0: + raise ValueError( + f"You must pass at least one shuffle to evaluate (had {list(Shuffles)})" + ) + elif len(engines) > 1: + raise ValueError( + f"All shuffles must have the same engine (found {list(engines)})" + ) + engine = engines.pop() + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import evaluate_network + + return evaluate_network( + str(config), + Shuffles=Shuffles, + trainingsetindex=trainingsetindex, + plotting=plotting, + show_errors=show_errors, + comparisonbodyparts=comparisonbodyparts, + gputouse=gputouse, + rescale=rescale, + modelprefix=modelprefix, + per_keypoint_evaluation=per_keypoint_evaluation, + snapshots_to_evaluate=snapshots_to_evaluate, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis import evaluate_network + + _update_device(gputouse, torch_kwargs) + return evaluate_network( + config, + shuffles=Shuffles, + trainingsetindex=trainingsetindex, + plotting=plotting, + show_errors=show_errors, + comparison_bodyparts=comparisonbodyparts, + per_keypoint_evaluation=per_keypoint_evaluation, + modelprefix=modelprefix, + pcutoff=pcutoff, + **torch_kwargs, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def return_evaluate_network_data( + config: str, + shuffle: int = 0, + trainingsetindex: int = 0, + comparisonbodyparts: str | list[str] = "all", + Snapindex: str | int | None = None, + rescale: bool = False, + fulldata: bool = False, + show_errors: bool = True, + modelprefix: str = "", + returnjustfns: bool = True, + engine: Engine | None = None, +): + """ + Returns the results for (previously evaluated) network. deeplabcut.evaluate_network(..) + Returns list of (per model): [trainingsiterations,trainfraction,shuffle,trainerror,testerror,pcutoff,trainerrorpcutoff,testerrorpcutoff,Snapshots[snapindex],scale,net_type] + + This function is only implemented for tensorflow models/shuffles, and will throw + an error if called with a PyTorch shuffle. + + If fulldata=True, also returns (the complete annotation and prediction array) + Returns list of: (DataMachine, Data, data, trainIndices, testIndices, trainFraction, DLCscorer,comparisonbodyparts, cfg, Snapshots[snapindex]) + ---------- + config : string + Full path of the config.yaml file as a string. + + shuffle: integer + integers specifying shuffle index of the training dataset. The default is 0. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). This + variable can also be set to "all". + + comparisonbodyparts: list of bodyparts, Default is "all". + The average error will be computed for those body parts only (Has to be a subset of the body parts). + + rescale: bool, default False + Evaluate the model at the 'global_scale' variable (as set in the test/pose_config.yaml file for a particular project). I.e. every + image will be resized according to that scale and prediction will be compared to the resized ground truth. The error will be reported + in pixels at rescaled to the *original* size. I.e. For a [200,200] pixel image evaluated at global_scale=.5, the predictions are calculated + on [100,100] pixel images, compared to 1/2*ground truth and this error is then multiplied by 2!. The evaluation images are also shown for the + original size! + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Examples + -------- + If you do not want to plot + >>> deeplabcut._evaluate_network_data('/analysis/project/reaching-task/config.yaml', shuffle=[1]) + -------- + If you want to plot + >>> deeplabcut.evaluate_network('/analysis/project/reaching-task/config.yaml',shuffle=[1],plotting=True) + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import return_evaluate_network_data + + return return_evaluate_network_data( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + comparisonbodyparts=comparisonbodyparts, + Snapindex=Snapindex, + rescale=rescale, + fulldata=fulldata, + show_errors=show_errors, + modelprefix=modelprefix, + returnjustfns=returnjustfns, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def analyze_videos( + config: str, + videos: list[str], + videotype: str = "", + shuffle: int = 1, + trainingsetindex: int = 0, + gputouse: str | None = None, + save_as_csv: bool = False, + in_random_order: bool = True, + destfolder: str | None = None, + batchsize: int = None, + cropping: list[int] | None = None, + TFGPUinference: bool = True, + dynamic: tuple[bool, float, int] = (False, 0.5, 10), + modelprefix: str = "", + robust_nframes: bool = False, + allow_growth: bool = False, + use_shelve: bool = False, + auto_track: bool = True, + n_tracks: int | None = None, + animal_names: list[str] | None = None, + calibrate: bool = False, + identity_only: bool = False, + use_openvino: str | None = None, + engine: Engine | None = None, + **torch_kwargs, +): + """Makes prediction based on a trained network. + + The index of the trained network is specified by parameters in the config file + (in particular the variable 'snapshotindex'). + + The labels are stored as MultiIndex Pandas Array, which contains the name of + the network, body part name, (x, y) label position in pixels, and the + likelihood for each frame per body part. These arrays are stored in an + efficient Hierarchical Data Format (HDF) in the same directory where the video + is stored. However, if the flag save_as_csv is set to True, the data can also + be exported in comma-separated values format (.csv), which in turn can be + imported in many programs, such as MATLAB, R, Prism, etc. + + Parameters + ---------- + config: str + Full path of the config.yaml file. + + videos: list[str] + A list of strings containing the full paths to videos for analysis or a path to + the directory, where all the videos with same extension are stored. + + videotype: str, optional, default="" + Checks for the extension of the video in case the input to the video is a + directory. Only videos with this extension are analyzed. If left unspecified, + videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept. + + shuffle: int, optional, default=1 + An integer specifying the shuffle index of the training dataset used for + training the network. + + trainingsetindex: int, optional, default=0 + Integer specifying which TrainingsetFraction to use. + By default the first (note that TrainingFraction is a list in config.yaml). + + gputouse: int or None, optional, default=None + Only for the TensorFlow engine (for the PyTorch engine see the ``torch_kwargs``: + you can use ``device``). + Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a + GPU put ``None``. + See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + save_as_csv: bool, optional, default=False + Saves the predictions in a .csv file. + + in_random_order: bool, optional (default=True) + Whether or not to analyze videos in a random order. + This is only relevant when specifying a video directory in `videos`. + + destfolder: string or None, optional, default=None + Specifies the destination folder for analysis data. If ``None``, the path of + the video is used. Note that for subsequent analysis this folder also needs to + be passed. + + batchsize: int or None, optional, default=None + Currently not supported by the PyTorch engine. + Change batch size for inference; if given overwrites value in ``pose_cfg.yaml``. + + cropping: list or None, optional, default=None + Currently not supported by the PyTorch engine. + List of cropping coordinates as [x1, x2, y1, y2]. + Note that the same cropping parameters will then be used for all videos. + If different video crops are desired, run ``analyze_videos`` on individual + videos with the corresponding cropping coordinates. + + TFGPUinference: bool, optional, default=True + Only for the TensorFlow engine. + Perform inference on GPU with TensorFlow code. Introduced in "Pretraining + boosts out-of-domain robustness for pose estimation" by Alexander Mathis, + Mert Yüksekgönül, Byron Rogers, Matthias Bethge, Mackenzie W. Mathis. + Source: https://arxiv.org/abs/1909.11229 + + dynamic: tuple(bool, float, int) triple containing (state, det_threshold, margin) + If the state is true, then dynamic cropping will be performed. That means that + if an object is detected (i.e. any body part > detectiontreshold), then object + boundaries are computed according to the smallest/largest x position and + smallest/largest y position of all body parts. This window is expanded by the + margin and from then on only the posture within this crop is analyzed (until the + object is lost, i.e. >> deeplabcut.analyze_videos( + 'C:\\myproject\\reaching-task\\config.yaml', + ['C:\\yourusername\\rig-95\\Videos\\reachingvideo1.avi'], + ) + + Analyzing a single video on Linux/MacOS + + >>> deeplabcut.analyze_videos( + '/analysis/project/reaching-task/config.yaml', + ['/analysis/project/videos/reachingvideo1.avi'], + ) + + Analyze all videos of type ``avi`` in a folder + + >>> deeplabcut.analyze_videos( + '/analysis/project/reaching-task/config.yaml', + ['/analysis/project/videos'], + videotype='.avi', + ) + + Analyze multiple videos + + >>> deeplabcut.analyze_videos( + '/analysis/project/reaching-task/config.yaml', + [ + '/analysis/project/videos/reachingvideo1.avi', + '/analysis/project/videos/reachingvideo2.avi', + ], + ) + + Analyze multiple videos with ``shuffle=2`` + + >>> deeplabcut.analyze_videos( + '/analysis/project/reaching-task/config.yaml', + [ + '/analysis/project/videos/reachingvideo1.avi', + '/analysis/project/videos/reachingvideo2.avi', + ], + shuffle=2, + ) + + Analyze multiple videos with ``shuffle=2``, save results as an additional csv file + + >>> deeplabcut.analyze_videos( + '/analysis/project/reaching-task/config.yaml', + [ + '/analysis/project/videos/reachingvideo1.avi', + '/analysis/project/videos/reachingvideo2.avi', + ], + shuffle=2, + save_as_csv=True, + ) + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import analyze_videos + + kwargs = {} + if use_openvino is not None: # otherwise default comes from tensorflow API + kwargs["use_openvino"] = use_openvino + + return analyze_videos( + config, + videos, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + gputouse=gputouse, + save_as_csv=save_as_csv, + in_random_order=in_random_order, + destfolder=destfolder, + batchsize=batchsize, + cropping=cropping, + TFGPUinference=TFGPUinference, + dynamic=dynamic, + modelprefix=modelprefix, + robust_nframes=robust_nframes, + allow_growth=allow_growth, + use_shelve=use_shelve, + auto_track=auto_track, + n_tracks=n_tracks, + animal_names=animal_names, + calibrate=calibrate, + identity_only=identity_only, + **kwargs, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis import analyze_videos + + _update_device(gputouse, torch_kwargs) + + if batchsize is not None: + if "batch_size" in torch_kwargs: + print( + f"You called analyze_videos with parameters ``batchsize={batchsize}" + f"`` and batch_size={torch_kwargs['batch_size']}. Only one is " + f"needed/used. Using batch size {torch_kwargs['batch_size']}" + ) + else: + torch_kwargs["batch_size"] = batchsize + + return analyze_videos( + config, + videos=videos, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + save_as_csv=save_as_csv, + in_random_order=in_random_order, + destfolder=destfolder, + dynamic=dynamic, + modelprefix=modelprefix, + use_shelve=use_shelve, + robust_nframes=robust_nframes, + auto_track=auto_track, + n_tracks=n_tracks, + animal_names=animal_names, + calibrate=calibrate, + identity_only=identity_only, + overwrite=False, + cropping=cropping, + **torch_kwargs, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def create_tracking_dataset( + config: str, + videos: list[str], + track_method: str, + videotype: str = "", + shuffle: int = 1, + trainingsetindex: int = 0, + gputouse: int | None = None, + destfolder: str | None = None, + batchsize: int | None = None, + cropping: list[int] | None = None, + TFGPUinference: bool = True, + modelprefix: str = "", + robust_nframes: bool = False, + n_triplets: int = 1000, + engine: Engine | None = None, +) -> str: + """Creates a tracking dataset to train a ReID tracklet stitcher. + + Parameters + ---------- + config: str + Full path of the config.yaml file. + + videos: list[str] + A list of strings containing the full paths to videos from which to create a + tracking dataset, or a path to the directory where all the videos with same + extension are stored. + + track_method: str + Specifies the tracker used to generate the pose estimation data. Must be either + 'box', 'skeleton', or 'ellipse'. + + videotype: str, optional, default="" + Checks for the extension of the video in case the input to the video is a + directory. Only videos with this extension are analyzed. If left unspecified, + videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept. + + shuffle: int, optional, default=1 + An integer specifying the shuffle index of the training dataset used for + training the network. + + trainingsetindex: int, optional, default=0 + Integer specifying which TrainingsetFraction to use. + By default the first (note that TrainingFraction is a list in config.yaml). + + gputouse: int or None, optional, default=None + Only for the TensorFlow engine (for the PyTorch engine use ``device``). + Indicates the GPU to use (see number in ``nvidia-smi``). If you do not have a + GPU put ``None``. See: + https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + TFGPUinference: bool, optional, default=True + Only for the TensorFlow engine. + Perform inference on GPU with TensorFlow code. Introduced in "Pretraining + boosts out-of-domain robustness for pose estimation" by Alexander Mathis, + Mert Yüksekgönül, Byron Rogers, Matthias Bethge, Mackenzie W. Mathis. + Source: https://arxiv.org/abs/1909.11229 + + destfolder: + Specifies the destination folder for analysis data. If ``None``, the path of + the video is used. Note that for subsequent analysis this folder also needs to + be passed. + + modelprefix: str, optional, default="" + Directory containing the deeplabcut models to use when evaluating the network. + By default, the models are assumed to exist in the project folder. + + robust_nframes: bool, optional, default=False + Evaluate a video's number of frames in a robust manner. + This option is slower (as the whole video is read frame-by-frame), + but does not rely on metadata, hence its robustness against file corruption. + + n_triplets: int, default=1000 + The number of triplets to extract for the dataset. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Returns + ------- + DLCScorer: str + the scorer used to analyze the videos + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import create_tracking_dataset + + return create_tracking_dataset( + config, + videos, + track_method, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + gputouse=gputouse, + destfolder=destfolder, + batchsize=batchsize, + cropping=cropping, + TFGPUinference=TFGPUinference, + modelprefix=modelprefix, + robust_nframes=robust_nframes, + n_triplets=n_triplets, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis import create_tracking_dataset + return create_tracking_dataset( + config, + videos, + track_method, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + destfolder=destfolder, + batch_size=batchsize, + cropping=cropping, + modelprefix=modelprefix, + robust_nframes=robust_nframes, + n_triplets=n_triplets, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def analyze_images( + config: str | Path, + images: str | Path | list[str] | list[Path], + frame_type: str | None = None, + destfolder: str | Path | None = None, + shuffle: int = 1, + trainingsetindex: int = 0, + max_individuals: int | None = None, + device: str | None = None, + snapshot_index: int | None = None, + detector_snapshot_index: int | None = None, + save_as_csv: bool = False, + modelprefix: str = "", + plotting: bool | str = False, + pcutoff: float | None = None, + bbox_pcutoff: float | None = None, + plot_skeleton: bool = False, +) -> dict[str, dict[str, np.ndarray | np.ndarray]]: + """Analyzes images with a DeepLabCut model and stores the output in an H5 file. + + This method is only implemented for PyTorch models. + + The labels are stored as Pandas DataFrame, which contains the name of the network, + body part name, (x, y) label position in pixels, and the likelihood for each frame + per body part. + + Parameters + ---------- + config : str, Path + Full path of the project's config.yaml file. + + images: str, Path, list[str], list[Path] + The image(s) to run inference on. Can be the path to an image, the path + to a directory containing images, or a list of image paths or directories + containing images. + + frame_type: string, optional + Filters the images to analyze to only the ones with the given suffix (e.g. + setting `frame_type`=".png" will only analyze ".png" images). The default + behavior analyzes all ".jpg", ".jpeg" and ".png" images. + + destfolder: str, Path, optional + The directory where the predictions will be stored. If None, the predictions + will be stored in the same directory as the first image given in the `images` + argument (if it's a directory, that directory will be used; if it's an image, + the directory containing the image will be used). + + shuffle: int, optional + An integer specifying the shuffle with which to run image analysis. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default, the first one + is used (note that TrainingFraction is a list in config.yaml). + + max_individuals: int, optional + The maximum number of individuals to detect in each image. Set to the number of + individuals in the project if None. + + device: str, optional + The CUDA device to use for training. If None, the device will be taken from the + ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. For + more information, see https://pytorch.org/docs/stable/notes/cuda.html + + snapshot_index: int, optional + Index (starting at 0) of the snapshot to use for image analysis. To evaluate the + last one, use -1. Default uses the value set in the project config. + + detector_snapshot_index: int, optional + Only for Top-Down PyTorch models. If defined, uses the detector with the given + index for pose estimation. To evaluate the last one, use -1. Default uses the + value set in the project config. + + save_as_csv: bool, optional + Saves the predictions in a .csv file. The default is ``False``; if provided it + must be either ``True`` or ``False``. + + modelprefix: str, optional + Directory containing the deeplabcut models to use when running image analysis. + By default, the models are assumed to exist in the project folder. + + plotting: bool, str, default=False + Plots the predictions made by the model on the analyzed images. Results will be + stored in a folder named `LabeledImages_{scorer}`, where scorer is the name + of the model used to analyze the images. This folder will be in the same + directory as the file containing the predictions (either the given `destfolder`, + or the folder containing the first image to analyze). + + If provided it must be either ``True``, ``False``, ``"bodypart"``, or + ``"individual"``. Setting to ``True`` defaults as ``"bodypart"`` for + multi-animal projects. If a detector is used, the predicted bounding boxes + will also be plotted. + + pcutoff: float, optional, default=None + The cutoff score when plotting pose predictions. Must be None or in + (0, 1). If None, the pcutoff is read from the project configuration file. + + bbox_pcutoff: float, optional, default=None + The cutoff score when plotting bounding box predictions. Must be + None or in (0, 1). If None, it is read from the project configuration file. + + plot_skeleton: bool, default=False + If a skeleton is defined in the project's config.yaml, whether + to plot the skeleton connecting the predicted bodyparts on the images. + + Returns + ------- + A dictionary mapping image paths (as strings) to model predictions. + + Examples + -------- + If you want to analyze all frames in /analysis/project/my_images + >>> import deeplabcut + >>> deeplabcut.analyze_images( + >>> "/analysis/project/reaching-task/config.yaml", + >>> "/analysis/project/my_images", + >>> ) + >>> + + If you want to analyze two specific images with your shuffle 3 model: + >>> import deeplabcut + >>> deeplabcut.analyze_images( + >>> "/analysis/project/reaching-task/config.yaml", + >>> images=["image_001.png", "img_002.jpg"], + >>> shuffle=3, + >>> ) + >>> + + If you want to analyze frames in a folder, save them and plot predictions: + >>> import deeplabcut + >>> deeplabcut.analyze_images( + >>> "/analysis/project/reaching-task/config.yaml", + >>> "/analysis/project/my_images", + >>> shuffle=3, + >>> destfolder="/analysis/project/my_images_analyzed", + >>> plotting=True, + >>> ) + >>> + -------- + """ + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch import analyze_images + + return analyze_images( + config=config, + images=images, + frame_type=frame_type, + output_dir=destfolder, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + snapshot_index=snapshot_index, + detector_snapshot_index=detector_snapshot_index, + modelprefix=modelprefix, + device=device, + save_as_csv=save_as_csv, + max_individuals=max_individuals, + plotting=plotting, + pcutoff=pcutoff, + bbox_pcutoff=bbox_pcutoff, + plot_skeleton=plot_skeleton, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def analyze_time_lapse_frames( + config: str, + directory: str, + frametype: str = ".png", + shuffle: int = 1, + trainingsetindex: int = 0, + gputouse: int | None = None, + device: str | None = None, + save_as_csv: bool = False, + modelprefix: str = "", + engine: Engine | None = None, +): + """ + Analyzed all images (of type = frametype) in a folder and stores the output in one file. + + You can crop the frames (before analysis), by changing 'cropping'=True and setting + 'x1','x2','y1','y2' in the config file. + + Output: The labels are stored as MultiIndex Pandas Array, which contains the name + of the network, body part name, (x, y) label position in pixels, and the likelihood + for each frame per body part. These arrays are stored in an efficient Hierarchical + Data Format (HDF) in the same directory, where the video is stored. However, if the + flag save_as_csv is set to True, the data can also be exported in comma-separated + values format (.csv), which in turn can be imported in many programs, such as + MATLAB, R, Prism, etc. + + Parameters + ---------- + config : string + Full path of the config.yaml file as a string. + + directory: string + Full path to directory containing the frames that shall be analyzed + + frametype: string, optional + Checks for the file extension of the frames. Only images with this extension are + analyzed. The default is ``.png`` + + shuffle: int, optional + An integer specifying the shuffle index of the training dataset used for + training the network. The default is 1. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note + that TrainingFraction is a list in config.yaml). + + gputouse: int, optional. + Only for TensorFlow models. For PyTorch models, please use `device`. Natural + number indicating the number of your GPU (see number in nvidia-smi). If you do + not have a GPU put None. See: + https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + device: str, optional + The CUDA device to use for training. If None, the device will be taken from the + ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. For + more information, see https://pytorch.org/docs/stable/notes/cuda.html + + save_as_csv: bool, optional + Saves the predictions in a .csv file. The default is ``False``; if provided if + must be either ``True`` or ``False`` + + Examples + -------- + If you want to analyze all frames in /analysis/project/timelapseexperiment1 + >>> import deeplabcut + >>> deeplabcut.analyze_time_lapse_frames( + >>> '/analysis/project/reaching-task/config.yaml', + >>> '/analysis/project/timelapseexperiment1' + >>> ) + + -------- + + Note: for test purposes one can extract all frames from a video with ffmeg, e.g. + >>> ffmpeg -i testvideo.avi "thumb%04d.png" + + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import analyze_time_lapse_frames + + return analyze_time_lapse_frames( + config, + directory, + frametype=frametype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + gputouse=gputouse, + save_as_csv=save_as_csv, + modelprefix=modelprefix, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch import analyze_images + + return analyze_images( + config=config, + images=directory, + output_dir=directory, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + device=_gpu_to_use_to_device(gputouse, device), + save_as_csv=save_as_csv, + modelprefix=modelprefix, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def convert_detections2tracklets( + config: str, + videos: list[str], + videotype: str = "", + shuffle: int = 1, + trainingsetindex: int = 0, + overwrite: bool = False, + destfolder: str | None = None, + ignore_bodyparts: list[str] | None = None, + inferencecfg: dict | None = None, + modelprefix: str = "", + greedy: bool = False, + calibrate: bool = False, + window_size: int = 0, + identity_only: int = False, + track_method: str = "", + engine: Engine | None = None, +): + """ + This should be called at the end of deeplabcut.analyze_videos for multianimal projects! + + Parameters + ---------- + config : string + Full path of the config.yaml file as a string. + + videos : list + A list of strings containing the full paths to videos for analysis or a path to the directory, where all the videos with same extension are stored. + + videotype: string, optional + Checks for the extension of the video in case the input to the video is a directory.\n Only videos with this extension are analyzed. + If left unspecified, videos with common extensions ('avi', 'mp4', 'mov', 'mpeg', 'mkv') are kept. + + shuffle: int, optional + An integer specifying the shuffle index of the training dataset used for training the network. The default is 1. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). + + overwrite: bool, optional. + Overwrite tracks file i.e. recompute tracks from full detections and overwrite. + + destfolder: string, optional + Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this + folder also needs to be passed. + + ignore_bodyparts: optional + List of body part names that should be ignored during tracking (advanced). + By default, all the body parts are used. + + inferencecfg: Default is None. + Configuration file for inference (assembly of individuals). Ideally + should be obtained from cross validation (during evaluation). By default + the parameters are loaded from inference_cfg.yaml, but these get_level_values + can be overwritten. + + calibrate: bool, optional (default=False) + If True, use training data to calibrate the animal assembly procedure. + This improves its robustness to wrong body part links, + but requires very little missing data. + + window_size: int, optional (default=0) + Recurrent connections in the past `window_size` frames are + prioritized during assembly. By default, no temporal coherence cost + is added, and assembly is driven mainly by part affinity costs. + + identity_only: bool, optional (default=False) + If True and animal identity was learned by the model, + assembly and tracking rely exclusively on identity prediction. + + track_method: string, optional + Specifies the tracker used to generate the pose estimation data. + For multiple animals, must be either 'box', 'skeleton', or 'ellipse' + and will be taken from the config.yaml file if none is given. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Examples + -------- + If you want to convert detections to tracklets: + >>> import deeplabcut + >>> deeplabcut.convert_detections2tracklets( + >>> "/analysis/project/reaching-task/config.yaml", + >>> ["/analysis/project/video1.mp4"], + >>> videotype='.mp4', + >>> ) + + If you want to convert detections to tracklets based on box_tracker: + >>> import deeplabcut + >>> deeplabcut.convert_detections2tracklets( + >>> "/analysis/project/reaching-task/config.yaml", + >>> ["/analysis/project/video1.mp4"], + >>> videotype=".mp4", + >>> track_method="box", + >>> ) + + -------- + + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import convert_detections2tracklets + + return convert_detections2tracklets( + config, + videos, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + overwrite=overwrite, + destfolder=destfolder, + ignore_bodyparts=ignore_bodyparts, + inferencecfg=inferencecfg, + modelprefix=modelprefix, + greedy=greedy, + calibrate=calibrate, + window_size=window_size, + identity_only=identity_only, + track_method=track_method, + ) + + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis import convert_detections2tracklets + + if greedy or calibrate or window_size: + raise NotImplementedError( + f"The 'greedy', 'calibrate' and 'window_size' option are not yet " + f"implemented with {engine}" + ) + + return convert_detections2tracklets( + config, + videos, + videotype=videotype, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + overwrite=overwrite, + destfolder=destfolder, + ignore_bodyparts=ignore_bodyparts, + inferencecfg=inferencecfg, + modelprefix=modelprefix, + identity_only=identity_only, + track_method=track_method, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def extract_maps( + config, + shuffle: int = 0, + trainingsetindex: int = 0, + gputouse: int | None = None, + device: str | None = None, + rescale: bool = False, + Indices: list[int] | None = None, + modelprefix: str = "", + engine: Engine | None = None, +): + """ + Extracts the scoremap, locref, partaffinityfields (if available). + + Returns a dictionary indexed by: trainingsetfraction, snapshotindex, and imageindex + for those keys, each item contains: (image, scmap, locref, paf, bpt_names, + partaffinity_graph, imagename, True/False if this image was in trainingset). + + ---------- + config : string + Full path of the config.yaml file as a string. + + shuffle: integer + integers specifying shuffle index of the training dataset. The default is 0. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note + that TrainingFraction is a list in config.yaml). This variable can also be set + to "all". + + gputouse: int or None, optional, default=None + For the TensorFlow engine (for the PyTorch engine see ``device``). Specifies + the GPU to use (see number in ``nvidia-smi``). If you do not have a GPU put + ``None``. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + device: str or None, optional, default=None + The CUDA device to use for training. If None, the device will be taken from the + ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. See + https://pytorch.org/docs/stable/notes/cuda.html for more information. + + rescale: bool, default False + Evaluate the model at the 'global_scale' variable (as set in the test/pose_config.yaml file for a particular project). I.e. every + image will be resized according to that scale and prediction will be compared to the resized ground truth. The error will be reported + in pixels at rescaled to the *original* size. I.e. For a [200,200] pixel image evaluated at global_scale=.5, the predictions are calculated + on [100,100] pixel images, compared to 1/2*ground truth and this error is then multiplied by 2!. The evaluation images are also shown for the + original size! + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Examples + -------- + If you want to extract the data for image 0 and 103 (of the training set) for model trained with shuffle 0. + >>> deeplabcut.extract_maps(configfile,0,Indices=[0,103]) + + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import extract_maps + + return extract_maps( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + gputouse=gputouse, + rescale=rescale, + Indices=Indices, + modelprefix=modelprefix, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch import extract_maps + + return extract_maps( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + device=_gpu_to_use_to_device(gputouse, device), + rescale=rescale, + indices=Indices, + modelprefix=modelprefix, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def visualize_scoremaps(image: np.ndarray, scmap: np.ndarray): + """Plots scoremaps as an image overlay. + + Args: + image: An image as a numpy array of shape (h, w, channels) + scmap: A scoremap of shape (h, w) + + Returns: + The figure and axis on which the image scoremap was plot. + """ + return visualization.visualize_scoremaps(image, scmap) + + +def visualize_locrefs( + image: np.ndarray, + scmap: np.ndarray, + locref_x: np.ndarray, + locref_y: np.ndarray, + step: int = 5, + zoom_width: int = 0, +): + """Plots a scoremap and the corresponding location refinement field on an image. + + Args: + image: An image as a numpy array of shape (h, w, channels) + scmap: A scoremap of shape (h, w) + locref_x: The x-coordinate of the location refinement field, of shape (h, w) + locref_y: The y-coordinate of the location refinement field, of shape (h, w) + step: The step with which to plot the location refinement field. + zoom_width: The zoom width with which to plot the scoremaps. + + Returns: + The figure and axis on which the image scoremap and locref field were plot. + """ + return visualization.visualize_locrefs( + image, scmap, locref_x, locref_y, step=step, zoom_width=zoom_width + ) + + +def visualize_paf( + image: np.ndarray, + paf: np.ndarray, + step: int = 5, + colors: list | None = None, +): + """Plots the PAF on top of the image. + + Args: + image: Shape (height, width, channels). The image on which the model was run. + paf: Shape (height, width, 2 * len(paf_graph)). The PAF output by the model. + step: The step with which to plot the scoremaps. + colors: The colormap to use. + + Returns: + The figure and axis on which the image PAF was plot. + """ + return visualization.visualize_paf(image, paf, step=step, colors=colors) + + +def extract_save_all_maps( + config, + shuffle: int = 1, + trainingsetindex: int = 0, + comparisonbodyparts: str | list[str] = "all", + extract_paf: bool = True, + all_paf_in_one: bool = True, + gputouse: int | None = None, + device: str | None = None, + rescale: bool = False, + Indices: list[int] | None = None, + modelprefix: str = "", + dest_folder: str = None, + snapshot_index: int | str | None = None, + detector_snapshot_index: int | str | None = None, + engine: Engine | None = None, +): + """ + Extracts the scoremap, location refinement field and part affinity field prediction of the model. The maps + will be rescaled to the size of the input image and stored in the corresponding model folder in /evaluation-results. + + ---------- + config : string + Full path of the config.yaml file as a string. + + shuffle: integer + integers specifying shuffle index of the training dataset. The default is 1. + + trainingsetindex: int, optional + Integer specifying which TrainingsetFraction to use. By default the first (note that TrainingFraction is a list in config.yaml). This + variable can also be set to "all". + + comparisonbodyparts: list of bodyparts, Default is "all". + The average error will be computed for those body parts only (Has to be a subset of the body parts). + + extract_paf : bool + Extract part affinity fields by default. + Note that turning it off will make the function much faster. + + all_paf_in_one : bool + By default, all part affinity fields are displayed on a single frame. + If false, individual fields are shown on separate frames. + + gputouse: int or None, optional, default=None + For the TensorFlow engine (for the PyTorch engine see ``device``). Specifies + the GPU to use (see number in ``nvidia-smi``). If you do not have a GPU put + ``None``. See: https://nvidia.custhelp.com/app/answers/detail/a_id/3751/~/useful-nvidia-smi-queries + + device: str or None, optional, default=None + The CUDA device to use for training. If None, the device will be taken from the + ``pytorch_config.yaml`` file. Examples: {"cpu", "cuda", "cuda:0", "cuda:1"}. See + https://pytorch.org/docs/stable/notes/cuda.html for more information. + + Indices: default None + For which images shall the scmap/locref and paf be computed? Give a list of images + + nplots_per_row: int, optional (default=None) + Number of plots per row in grid plots. By default, calculated to approximate a squared grid of plots + + snapshot_index: Only for PyTorch models. Index (starting at 0) of the snapshot we + want to extract maps with. To evaluate the last one, use -1. To extract maps + for all snapshots, use "all". Default uses the value set in the project config. + + detector_snapshot_index: Only for TD PyTorch models. If defined, uses the detector + with the given index for pose estimation. To extract maps for all detector + snapshots, use "all". Default uses the value set in the project config. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Examples + -------- + Calculated maps for images 0, 1 and 33. + >>> deeplabcut.extract_save_all_maps('/analysis/project/reaching-task/config.yaml', shuffle=1,Indices=[0,1,33]) + + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(config), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import extract_save_all_maps + + return extract_save_all_maps( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + comparisonbodyparts=comparisonbodyparts, + extract_paf=extract_paf, + all_paf_in_one=all_paf_in_one, + gputouse=gputouse, + rescale=rescale, + Indices=Indices, + modelprefix=modelprefix, + dest_folder=dest_folder, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch import extract_save_all_maps + + return extract_save_all_maps( + config, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + comparison_bodyparts=comparisonbodyparts, + extract_paf=extract_paf, + all_paf_in_one=all_paf_in_one, + device=_gpu_to_use_to_device(gputouse, device), + rescale=rescale, + indices=Indices, + modelprefix=modelprefix, + snapshot_index=snapshot_index, + detector_snapshot_index=detector_snapshot_index, + dest_folder=dest_folder, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def export_model( + cfg_path: str, + shuffle: int = 1, + trainingsetindex: int = 0, + snapshotindex: int | None = None, + iteration: int = None, + TFGPUinference: bool = True, + overwrite: bool = False, + make_tar: bool = True, + wipepaths: bool = False, + without_detector: bool = False, + modelprefix: str = "", + engine: Engine | None = None, +) -> None: + """Export DeepLabCut models for the model zoo or for live inference. + + Saves the pose configuration, snapshot files, and frozen TF graph of the model to + directory named exported-models within the project directory (and an + `exported-models-pytorch` directory for PyTorch models). + + Parameters + ----------- + + cfg_path : string + path to the DLC Project config.yaml file + + shuffle : int, optional + the shuffle of the model to export. default = 1 + + trainingsetindex : int, optional + the index of the training fraction for the model you wish to export. default = 1 + + snapshotindex : int, optional + the snapshot index for the weights you wish to export. If None, + uses the snapshotindex as defined in 'config.yaml'. Default = None + + iteration : int, optional + The model iteration (active learning loop) you wish to export. If None, + the iteration listed in the config file is used. + + TFGPUinference : bool, optional + use the tensorflow inference model? Default = True + For inference using DeepLabCut-live, it is recommended to set TFGPIinference=False + + overwrite : bool, optional + if the model you wish to export has already been exported, whether to overwrite. default = False + + make_tar : bool, optional + Do you want to compress the exported directory to a tar file? Default = True + This is necessary to export to the model zoo, but not for live inference. + + wipepaths : bool, optional + Removes the actual path of your project and the init_weights from pose_cfg. + + without_detector: bool, optional + PyTorch engine only. Exports top-down models without the detector. + + engine: Engine, optional, default = None. + The default behavior loads the engine for the shuffle from the metadata. You can + overwrite this by passing the engine as an argument, but this should generally + not be done. + + Example: + -------- + Export the first stored snapshot for model trained with shuffle 3: + >>> deeplabcut.export_model('/analysis/project/reaching-task/config.yaml',shuffle=3, snapshotindex=-1) + -------- + """ + if engine is None: + engine = get_shuffle_engine( + _load_config(cfg_path), + trainingsetindex=trainingsetindex, + shuffle=shuffle, + modelprefix=modelprefix, + ) + + if engine == Engine.TF: + from deeplabcut.pose_estimation_tensorflow import export_model + + return export_model( + cfg_path=cfg_path, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + snapshotindex=snapshotindex, + iteration=iteration, + TFGPUinference=TFGPUinference, + overwrite=overwrite, + make_tar=make_tar, + wipepaths=wipepaths, + modelprefix=modelprefix, + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.apis.export import export_model + + return export_model( + config=cfg_path, + shuffle=shuffle, + trainingsetindex=trainingsetindex, + snapshotindex=snapshotindex, + iteration=iteration, + overwrite=overwrite, + wipe_paths=wipepaths, + without_detector=without_detector, + modelprefix=modelprefix, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def _update_device(gpu_to_use: int | None, torch_kwargs: dict) -> None: + if "device" not in torch_kwargs and gpu_to_use is not None: + device = _gpu_to_use_to_device(gpu_to_use, device=None) + if device is not None: + torch_kwargs["device"] = device + + +def _gpu_to_use_to_device(gpu_to_use: int | None, device: str | None) -> str | None: + if device is None and gpu_to_use is not None: + if isinstance(gpu_to_use, int): + device = f"cuda:{gpu_to_use}" + else: + device = gpu_to_use + + return device + + +def _load_config(config: str) -> dict: + config_path = Path(config) + if not config_path.exists(): + raise FileNotFoundError( + f"Config {config} is not found. Please make sure that the file exists." + ) + + with open(config, "r") as f: + project_config = YAML(typ="safe", pure=True).load(f) + + return project_config diff --git a/deeplabcut/core/__init__.py b/deeplabcut/core/__init__.py new file mode 100644 index 0000000000..117d127147 --- /dev/null +++ b/deeplabcut/core/__init__.py @@ -0,0 +1,10 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# diff --git a/deeplabcut/core/config.py b/deeplabcut/core/config.py new file mode 100644 index 0000000000..1a638e48da --- /dev/null +++ b/deeplabcut/core/config.py @@ -0,0 +1,74 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Simple helper methods related to configuration files stored in yaml files""" +from __future__ import annotations + +from pathlib import Path +from typing import Callable + +from ruamel.yaml import YAML + + +def read_config_as_dict(config_path: str | Path) -> dict: + """ + Args: + config_path: the path to the configuration file to load + + Returns: + The configuration file with pure Python classes + """ + with open(config_path, "r") as f: + cfg = YAML(typ="safe", pure=True).load(f) + + return cfg + + +def write_config(config_path: str | Path, config: dict, overwrite: bool = True) -> None: + """Writes a pose configuration file to disk + + Args: + config_path: the path where the config should be saved + config: the config to save + overwrite: whether to overwrite the file if it already exists + + Raises: + FileExistsError if overwrite=True and the file already exists + """ + if not overwrite and Path(config_path).exists(): + raise FileExistsError( + f"Cannot write to {config_path} - set overwrite=True to force" + ) + + with open(config_path, "w") as file: + YAML().dump(config, file) + + +def pretty_print( + config: dict, + indent: int = 0, + print_fn: Callable[[str], None] | None = None, +) -> None: + """Prints a model configuration in a pretty and readable way + + Args: + config: the config to print + indent: the base indent on all keys + print_fn: custom function to call (simply calls ``print`` if None) + """ + if print_fn is None: + print_fn = print + + for k, v in config.items(): + if isinstance(v, dict): + print_fn(f"{indent * ' '}{k}:") + pretty_print(v, indent + 2, print_fn=print_fn) + else: + print_fn(f"{indent * ' '}{k}: {v}") diff --git a/deeplabcut/core/conversion_table.py b/deeplabcut/core/conversion_table.py new file mode 100644 index 0000000000..e5d9679fa9 --- /dev/null +++ b/deeplabcut/core/conversion_table.py @@ -0,0 +1,79 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Defines conversion tables mapping DeepLabCut project bodyparts to SA bodyparts""" +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class ConversionTable: + """Maps DLC project bodyparts to the corresponding SuperAnimal bodyparts + + The conversion table must satisfy the following conditions (checked by validate): + - All SuperAnimal bodyparts must be valid (defined for the SuperAnimal model) + - All project bodyparts must be valid (defined for the DLC project) + """ + + super_animal: str + project_bodyparts: list[str] + super_animal_bodyparts: list[str] + table: dict[str, str] + + def __post_init__(self): + """Validates the table""" + self.validate() + + def to_array(self) -> np.ndarray: + """ + Returns: + An array mapping the indices of SuperAnimal bodyparts + + Raises: + ValueError: If the conversion table is misconfigured. + """ + self.validate() + sa_indices = {sa_bpt: i for i, sa_bpt in enumerate(self.super_animal_bodyparts)} + sa_bpt_ordering = [self.table[bpt] for bpt in self.converted_bodyparts()] + return np.array([sa_indices[sa_bpt] for sa_bpt in sa_bpt_ordering]) + + def converted_bodyparts(self) -> list[str]: + """Returns: The project bodyparts included in this ordered""" + return [bpt for bpt in self.project_bodyparts if bpt in self.table] + + def validate(self) -> None: + """ + Raises: + ValueError: If the conversion table is misconfigured. + """ + project_bpts = set(self.project_bodyparts) + sa_bpts = set(self.super_animal_bodyparts) + + mapped_sa = set(self.table.values()) + mapped_project = set(self.table.keys()) + + # check all mapped SuperAnimal bodyparts are in the config + if len(mapped_sa.difference(sa_bpts)) != 0: + extra_bodyparts = set(mapped_sa).difference(sa_bpts) + raise ValueError( + f"Some bodyparts in your mapping are not in the {self.super_animal} " + f"model: {extra_bodyparts}. Available bodyparts are {' '.join(sa_bpts)}" + ) + + # check all given bodyparts are in the project configuration + if len(mapped_project.difference(project_bpts)) != 0: + extra_bodyparts = mapped_project.difference(project_bpts) + raise ValueError( + "Some bodyparts in your mapping are not in your project configuration: " + f"{extra_bodyparts}. Defined bodyparts are {' '.join(project_bpts)}" + ) diff --git a/deeplabcut/core/crossvalutils.py b/deeplabcut/core/crossvalutils.py new file mode 100644 index 0000000000..e95b2c7591 --- /dev/null +++ b/deeplabcut/core/crossvalutils.py @@ -0,0 +1,484 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + + +import os +import pickle +import shutil +from collections import defaultdict +from copy import deepcopy + +import networkx as nx +import numpy as np +import pandas as pd +from scipy.spatial import cKDTree +from sklearn.metrics.cluster import contingency_matrix +from tqdm import tqdm + +from deeplabcut.core.inferenceutils import ( + _parse_ground_truth_data, + Assembler, + evaluate_assembly, +) +from deeplabcut.utils import auxfun_multianimal, auxiliaryfunctions + + +def _set_up_evaluation(data): + params = dict() + params["joint_names"] = data["metadata"]["all_joints_names"] + params["num_joints"] = len(params["joint_names"]) + partaffinityfield_graph = data["metadata"]["PAFgraph"] + params["paf"] = np.arange(len(partaffinityfield_graph)) + params["paf_graph"] = params["paf_links"] = [ + partaffinityfield_graph[l] for l in params["paf"] + ] + params["bpts"] = params["ibpts"] = range(params["num_joints"]) + params["imnames"] = [fn for fn in list(data) if fn != "metadata"] + return params + + +def _form_original_path(path): + root, filename = os.path.split(path) + base, ext = os.path.splitext(filename) + return os.path.join(root, filename.split("c")[0] + ext) + + +def _unsorted_unique(array): + _, inds = np.unique(array, return_index=True) + return np.asarray(array)[np.sort(inds)] + + +def find_closest_neighbors( + query: np.ndarray, ref: np.ndarray, k: int = 3 +) -> np.ndarray: + """Greedy matching of predicted keypoints to ground truth keypoints + + Args: + query: the query keypoints + ref: the reference keypoints + k: The list of k-th nearest neighbors to return. + + Returns: + an array of shape (len(query), ) containing the index of the closest + reference keypoint for each query keypoint + """ + n_preds = ref.shape[0] + tree = cKDTree(ref) + dist, inds = tree.query(query, k=k) + idx = np.argsort(dist[:, 0]) + neighbors = np.full(len(query), -1, dtype=int) + picked = {tree.n} + for i, ind in enumerate(inds[idx]): + for j in ind: + if j not in picked: + picked.add(j) + neighbors[idx[i]] = j + break + if len(picked) == (n_preds + 1): + break + return neighbors + + +def _calc_separability( + vals_left, vals_right, n_bins=101, metric="jeffries", max_sensitivity=False +): + if metric not in ("jeffries", "auc"): + raise ValueError("`metric` should be either 'jeffries' or 'auc'.") + + bins = np.linspace(0, 1, n_bins) + hist_left = np.histogram(vals_left, bins=bins)[0] + hist_left = hist_left / hist_left.sum() + hist_right = np.histogram(vals_right, bins=bins)[0] + hist_right = hist_right / hist_right.sum() + tpr = np.cumsum(hist_right) + if metric == "jeffries": + sep = np.sqrt( + 2 * (1 - np.sum(np.sqrt(hist_left * hist_right))) + ) # Jeffries-Matusita distance + else: + sep = np.trapz(np.cumsum(hist_left), tpr) + if max_sensitivity: + threshold = bins[max(1, np.argmax(tpr > 0))] + else: + threshold = bins[np.argmin(1 - np.cumsum(hist_left) + tpr)] + return sep, threshold + + +def _calc_within_between_pafs( + data, + metadata, + per_edge=True, + train_set_only=True, +): + data = deepcopy(data) + train_inds = set(metadata["data"]["trainIndices"]) + graph = data["metadata"]["PAFgraph"] + within_train = defaultdict(list) + within_test = defaultdict(list) + between_train = defaultdict(list) + between_test = defaultdict(list) + for i, (key, dict_) in enumerate(data.items()): + if key == "metadata": + continue + + is_train = i in train_inds + if train_set_only and not is_train: + continue + + df = dict_["groundtruth"][2] + try: + df.drop("single", level="individuals", inplace=True) + except KeyError: + pass + bpts = df.index.get_level_values("bodyparts").unique().to_list() + coords_gt = ( + df.unstack(["individuals", "coords"]) + .reindex(bpts, level="bodyparts") + .to_numpy() + .reshape((len(bpts), -1, 2)) + ) + if np.isnan(coords_gt).all(): + continue + + coords = dict_["prediction"]["coordinates"][0] + # Get animal IDs and corresponding indices in the arrays of detections + lookup = dict() + for i, (coord, coord_gt) in enumerate(zip(coords, coords_gt)): + inds = np.flatnonzero(np.all(~np.isnan(coord), axis=1)) + inds_gt = np.flatnonzero(np.all(~np.isnan(coord_gt), axis=1)) + if inds.size and inds_gt.size: + neighbors = find_closest_neighbors(coord_gt[inds_gt], coord[inds], k=3) + found = neighbors != -1 + lookup[i] = dict(zip(inds_gt[found], inds[neighbors[found]])) + + costs = dict_["prediction"]["costs"] + for k, v in costs.items(): + paf = v["m1"] + mask_within = np.zeros(paf.shape, dtype=bool) + s, t = graph[k] + if s not in lookup or t not in lookup: + continue + lu_s = lookup[s] + lu_t = lookup[t] + common_id = set(lu_s).intersection(lu_t) + for id_ in common_id: + mask_within[lu_s[id_], lu_t[id_]] = True + within_vals = paf[mask_within] + between_vals = paf[~mask_within] + if is_train: + within_train[k].extend(within_vals) + between_train[k].extend(between_vals) + else: + within_test[k].extend(within_vals) + between_test[k].extend(between_vals) + if not per_edge: + within_train = np.concatenate([*within_train.values()]) + within_test = np.concatenate([*within_test.values()]) + between_train = np.concatenate([*between_train.values()]) + between_test = np.concatenate([*between_test.values()]) + return (within_train, within_test), (between_train, between_test) + + +def _benchmark_paf_graphs( + config, + inference_cfg, + data, + paf_inds, + greedy=False, + add_discarded=True, + identity_only=False, + calibration_file="", + oks_sigma=0.1, + margin=0, + symmetric_kpts=None, + split_inds=None, +): + metadata = data.pop("metadata") + multi_bpts_orig = auxfun_multianimal.extractindividualsandbodyparts(config)[2] + multi_bpts = [j for j in metadata["all_joints_names"] if j in multi_bpts_orig] + n_multi = len(multi_bpts) + data_ = {"metadata": metadata} + for k, v in data.items(): + data_[k] = v["prediction"] + ass = Assembler( + data_, + max_n_individuals=inference_cfg["topktoretain"], + n_multibodyparts=n_multi, + greedy=greedy, + pcutoff=inference_cfg.get("pcutoff", 0.1), + min_affinity=inference_cfg.get("pafthreshold", 0.1), + add_discarded=add_discarded, + identity_only=identity_only, + ) + if calibration_file: + ass.calibrate(calibration_file) + + params = ass.metadata + image_paths = params["imnames"] + bodyparts = params["joint_names"] + idx = ( + data[image_paths[0]]["groundtruth"][2] + .unstack("coords") + .reindex(bodyparts, level="bodyparts") + .index + ) + mask_multi = idx.get_level_values("individuals") != "single" + if not mask_multi.all(): + idx = idx.drop("single", level="individuals") + individuals = idx.get_level_values("individuals").unique() + n_individuals = len(individuals) + map_ = dict(zip(individuals, range(n_individuals))) + + # Form ground truth beforehand + ground_truth = [] + for i, imname in enumerate(image_paths): + temp = data[imname]["groundtruth"][2].reindex(multi_bpts, level="bodyparts") + ground_truth.append(temp.to_numpy().reshape((-1, 2))) + ground_truth = np.stack(ground_truth) + temp = np.ones((*ground_truth.shape[:2], 3)) + temp[..., :2] = ground_truth + temp = temp.reshape((temp.shape[0], n_individuals, -1, 3)) + ass_true_dict = _parse_ground_truth_data(temp) + ids = np.vectorize(map_.get)(idx.get_level_values("individuals").to_numpy()) + ground_truth = np.insert(ground_truth, 2, ids, axis=2) + + # Assemble animals on the full set of detections + paf_inds = sorted(paf_inds, key=len) + n_graphs = len(paf_inds) + all_scores = [] + all_metrics = [] + all_assemblies = [] + for j, paf in enumerate(paf_inds, start=1): + print(f"Graph {j}|{n_graphs}") + ass.paf_inds = paf + ass.assemble() + all_assemblies.append((ass.assemblies, ass.unique, ass.metadata["imnames"])) + if split_inds is not None: + oks = [] + + # get the indices of the images in the training set + dataset_idx = [data[image_name]["index"] for image_name in image_paths] + for inds in split_inds: + ass_gt = { + k: v for k, v in ass_true_dict.items() if dataset_idx[k] in inds + } + ass_pred = { + k: v for k, v in ass.assemblies.items() if dataset_idx[k] in inds + } + + oks.append( + evaluate_assembly( + ass_pred, + ass_gt, + oks_sigma, + margin=margin, + symmetric_kpts=symmetric_kpts, + greedy_matching=inference_cfg.get("greedy_oks", False), + ) + ) + else: + oks = evaluate_assembly( + ass.assemblies, + ass_true_dict, + oks_sigma, + margin=margin, + symmetric_kpts=symmetric_kpts, + greedy_matching=inference_cfg.get("greedy_oks", False), + ) + all_metrics.append(oks) + scores = np.full((len(image_paths), 2), np.nan) + for i, imname in enumerate(tqdm(image_paths)): + gt = ground_truth[i] + gt = gt[~np.isnan(gt).any(axis=1)] + if len(np.unique(gt[:, 2])) < 2: # Only consider frames with 2+ animals + continue + + # Count the number of unassembled bodyparts + n_dets = len(gt) + animals = ass.assemblies.get(i) + if animals is None: + if n_dets: + scores[i, 0] = 1 + else: + animals = [ + np.c_[animal.data, np.ones(animal.data.shape[0]) * n] + for n, animal in enumerate(animals) + ] + hyp = np.concatenate(animals) + hyp = hyp[~np.isnan(hyp).any(axis=1)] + scores[i, 0] = max(0, (n_dets - hyp.shape[0]) / n_dets) + neighbors = find_closest_neighbors(gt[:, :2], hyp[:, :2]) + valid = neighbors != -1 + id_gt = gt[valid, 2] + id_hyp = hyp[neighbors[valid], -1] + mat = contingency_matrix(id_gt, id_hyp) + purity = mat.max(axis=0).sum() / mat.sum() + scores[i, 1] = purity + all_scores.append((scores, paf)) + + dfs = [] + for score, inds in all_scores: + df = pd.DataFrame(score, columns=["miss", "purity"]) + df["ngraph"] = len(inds) + dfs.append(df) + big_df = pd.concat(dfs) + group = big_df.groupby("ngraph") + return (all_scores, group.agg(["mean", "std"]).T, all_metrics, all_assemblies) + + +def _get_n_best_paf_graphs( + data, + metadata, + full_graph, + n_graphs=10, + root=None, + which="best", + ignore_inds=None, + metric="auc", +): + if which not in ("best", "worst"): + raise ValueError('`which` must be either "best" or "worst"') + + (within_train, _), (between_train, _) = _calc_within_between_pafs( + data, + metadata, + train_set_only=True, + ) + # Handle unlabeled bodyparts... + existing_edges = set(k for k, v in within_train.items() if v) + if ignore_inds is not None: + existing_edges = existing_edges.difference(ignore_inds) + existing_edges = list(existing_edges) + + if not any(between_train.values()): + # Only 1 animal, let us return the full graph indices only + return ([existing_edges], dict(zip(existing_edges, [0] * len(existing_edges)))) + + scores, _ = zip( + *[ + _calc_separability(between_train[n], within_train[n], metric=metric) + for n in existing_edges + ] + ) + + # Find minimal skeleton + G = nx.Graph() + for edge, score in zip(existing_edges, scores): + if np.isfinite(score): + G.add_edge(*full_graph[edge], weight=score) + if which == "best": + order = np.asarray(existing_edges)[np.argsort(scores)[::-1]] + if root is None: + root = [] + for edge in nx.maximum_spanning_edges(G, data=False): + root.append(full_graph.index(sorted(edge))) + else: + order = np.asarray(existing_edges)[np.argsort(scores)] + if root is None: + root = [] + for edge in nx.minimum_spanning_edges(G, data=False): + root.append(full_graph.index(sorted(edge))) + + n_edges = len(existing_edges) - len(root) + lengths = np.linspace(0, n_edges, min(n_graphs, n_edges + 1), dtype=int)[1:] + order = order[np.isin(order, root, invert=True)] + paf_inds = [root] + for length in lengths: + paf_inds.append(root + list(order[:length])) + return paf_inds, dict(zip(existing_edges, scores)) + + +def cross_validate_paf_graphs( + config, + inference_config, + full_data_file, + metadata_file, + output_name="", + pcutoff=0.1, + oks_sigma=0.1, + margin=0, + greedy=False, + add_discarded=True, + calibrate=False, + overwrite_config=True, + n_graphs=10, + paf_inds=None, + symmetric_kpts=None, +): + cfg = auxiliaryfunctions.read_config(config) + inf_cfg = auxiliaryfunctions.read_plainconfig(inference_config) + inf_cfg_temp = inf_cfg.copy() + inf_cfg_temp["pcutoff"] = pcutoff + + with open(full_data_file, "rb") as file: + data = pickle.load(file) + with open(metadata_file, "rb") as file: + metadata = pickle.load(file) + + params = _set_up_evaluation(data) + to_ignore = auxfun_multianimal.filter_unwanted_paf_connections( + cfg, params["paf_graph"] + ) + best_graphs = _get_n_best_paf_graphs( + data, + metadata, + params["paf_graph"], + ignore_inds=to_ignore, + n_graphs=n_graphs, + ) + paf_scores = best_graphs[1] + if paf_inds is None: + paf_inds = best_graphs[0] + + if calibrate: + trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg) + calibration_file = os.path.join( + cfg["project_path"], + str(trainingsetfolder), + "CollectedData_" + cfg["scorer"] + ".h5", + ) + else: + calibration_file = "" + + results = _benchmark_paf_graphs( + cfg, + inf_cfg_temp, + data, + paf_inds, + greedy, + add_discarded, + oks_sigma=oks_sigma, + margin=margin, + symmetric_kpts=symmetric_kpts, + calibration_file=calibration_file, + split_inds=[ + metadata["data"]["trainIndices"], + metadata["data"]["testIndices"], + ], + ) + # Select optimal PAF graph + df = results[1] + size_opt = np.argmax((1 - df.loc["miss", "mean"]) * df.loc["purity", "mean"]) + pose_config = inference_config.replace("inference_cfg", "pose_cfg") + if not overwrite_config: + shutil.copy(pose_config, pose_config.replace(".yaml", "_old.yaml")) + inds = list(paf_inds[size_opt]) + auxiliaryfunctions.edit_config( + pose_config, {"paf_best": [int(ind) for ind in inds]} + ) + if output_name: + with open(output_name, "wb") as file: + pickle.dump([results], file) + return results[:3], paf_scores, results[3][size_opt] + + +# Backwards compatibility +_find_closest_neighbors = find_closest_neighbors diff --git a/deeplabcut/core/engine.py b/deeplabcut/core/engine.py new file mode 100644 index 0000000000..c6f07ca69d --- /dev/null +++ b/deeplabcut/core/engine.py @@ -0,0 +1,49 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Defines the deep learning frameworks available""" +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +@dataclass(frozen=True) +class EngineDataMixin: + aliases: tuple[str] + model_folder_name: str + pose_cfg_name: str + results_folder_name: str + + +class Engine(EngineDataMixin, Enum): + PYTORCH = ( + ("pytorch", "torch"), + "dlc-models-pytorch", + "pytorch_config.yaml", + "evaluation-results-pytorch", + ) + TF = ( + ("tensorflow", "tf"), + "dlc-models", + "pose_cfg.yaml", + "evaluation-results", + ) + + @classmethod + def _missing_(cls, value): + if isinstance(value, str): + for member in cls: + if value.lower() in member.aliases: + return member + return None + + def __repr__(self) -> str: + return f"Engine.{self.name}" diff --git a/deeplabcut/core/inferenceutils.py b/deeplabcut/core/inferenceutils.py new file mode 100644 index 0000000000..cbd21a877a --- /dev/null +++ b/deeplabcut/core/inferenceutils.py @@ -0,0 +1,1314 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from __future__ import annotations + +import heapq +import itertools +import multiprocessing +import operator +import pickle +import warnings +from collections import defaultdict +from dataclasses import dataclass +from math import erf, sqrt +from typing import Any, Iterable, Tuple + +import networkx as nx +import numpy as np +import pandas as pd +from scipy.optimize import linear_sum_assignment +from scipy.spatial import cKDTree +from scipy.spatial.distance import cdist, pdist +from scipy.special import softmax +from scipy.stats import chi2, gaussian_kde +from tqdm import tqdm + + +def _conv_square_to_condensed_indices(ind_row, ind_col, n): + if ind_row == ind_col: + raise ValueError("There are no diagonal elements in condensed matrices.") + + if ind_row < ind_col: + ind_row, ind_col = ind_col, ind_row + return n * ind_col - ind_col * (ind_col + 1) // 2 + ind_row - 1 - ind_col + + +Position = Tuple[float, float] + + +@dataclass(frozen=True) +class Joint: + pos: Position + confidence: float = 1.0 + label: int = None + idx: int = None + group: int = -1 + + +class Link: + def __init__(self, j1, j2, affinity=1): + self.j1 = j1 + self.j2 = j2 + self.affinity = affinity + self._length = sqrt((j1.pos[0] - j2.pos[0]) ** 2 + (j1.pos[1] - j2.pos[1]) ** 2) + + def __repr__(self): + return ( + f"Link {self.idx}, affinity={self.affinity:.2f}, length={self.length:.2f}" + ) + + @property + def confidence(self): + return self.j1.confidence * self.j2.confidence + + @property + def idx(self): + return self.j1.idx, self.j2.idx + + @property + def length(self): + return self._length + + @length.setter + def length(self, length): + self._length = length + + def to_vector(self): + return [*self.j1.pos, *self.j2.pos] + + +class Assembly: + def __init__(self, size): + self.data = np.full((size, 4), np.nan) + self.confidence = 0 # 0 by default, overwritten otherwise with `add_joint` + self._affinity = 0 + self._links = [] + self._visible = set() + self._idx = set() + self._dict = dict() + + def __len__(self): + return len(self._visible) + + def __contains__(self, assembly): + return bool(self._visible.intersection(assembly._visible)) + + def __add__(self, other): + if other in self: + raise ValueError("Assemblies contain shared joints.") + + assembly = Assembly(self.data.shape[0]) + for link in self._links + other._links: + assembly.add_link(link) + return assembly + + @classmethod + def from_array(cls, array): + n_bpts, n_cols = array.shape + + # if a single coordinate is NaN for a bodypart, set all to NaN + array[np.isnan(array).any(axis=-1)] = np.nan + + ass = cls(size=n_bpts) + ass.data[:, :n_cols] = array + visible = np.flatnonzero(~np.isnan(array).any(axis=1)) + if n_cols < 3: # Only xy coordinates are being set + ass.data[visible, 2] = 1 # Set detection confidence to 1 + ass._visible.update(visible) + return ass + + @property + def xy(self): + return self.data[:, :2] + + @property + def extent(self): + bbox = np.empty(4) + bbox[:2] = np.nanmin(self.xy, axis=0) + bbox[2:] = np.nanmax(self.xy, axis=0) + return bbox + + @property + def area(self): + x1, y1, x2, y2 = self.extent + return (x2 - x1) * (y2 - y1) + + @property + def confidence(self): + return np.nanmean(self.data[:, 2]) + + @confidence.setter + def confidence(self, confidence): + self.data[:, 2] = confidence + + @property + def soft_identity(self): + data = self.data[~np.isnan(self.data).any(axis=1)] + unq, idx, cnt = np.unique(data[:, 3], return_inverse=True, return_counts=True) + avg = np.bincount(idx, weights=data[:, 2]) / cnt + soft = softmax(avg) + return dict(zip(unq.astype(int), soft)) + + @property + def affinity(self): + n_links = self.n_links + if not n_links: + return 0 + return self._affinity / n_links + + @property + def n_links(self): + return len(self._links) + + def intersection_with(self, other): + x11, y11, x21, y21 = self.extent + x12, y12, x22, y22 = other.extent + x1 = max(x11, x12) + y1 = max(y11, y12) + x2 = min(x21, x22) + y2 = min(y21, y22) + if x2 < x1 or y2 < y1: + return 0 + ll = np.array([x1, y1]) + ur = np.array([x2, y2]) + xy1 = self.xy[~np.isnan(self.xy).any(axis=1)] + xy2 = other.xy[~np.isnan(other.xy).any(axis=1)] + in1 = np.all((xy1 >= ll) & (xy1 <= ur), axis=1).sum() + in2 = np.all((xy2 >= ll) & (xy2 <= ur), axis=1).sum() + return min(in1 / len(self), in2 / len(other)) + + def add_joint(self, joint): + if joint.label in self._visible or joint.label is None: + return False + self.data[joint.label] = *joint.pos, joint.confidence, joint.group + self._visible.add(joint.label) + self._idx.add(joint.idx) + return True + + def remove_joint(self, joint): + if joint.label not in self._visible: + return False + self.data[joint.label] = np.nan + self._visible.remove(joint.label) + self._idx.remove(joint.idx) + return True + + def add_link(self, link, store_dict=False): + if store_dict: + # Selective copy; deepcopy is >5x slower + self._dict = { + "data": self.data.copy(), + "_affinity": self._affinity, + "_links": self._links.copy(), + "_visible": self._visible.copy(), + "_idx": self._idx.copy(), + } + i1, i2 = link.idx + if i1 in self._idx and i2 in self._idx: + self._affinity += link.affinity + self._links.append(link) + return False + if link.j1.label in self._visible and link.j2.label in self._visible: + return False + self.add_joint(link.j1) + self.add_joint(link.j2) + self._affinity += link.affinity + self._links.append(link) + return True + + def calc_pairwise_distances(self): + return pdist(self.xy, metric="sqeuclidean") + + +class Assembler: + def __init__( + self, + data, + *, + max_n_individuals, + n_multibodyparts, + graph=None, + paf_inds=None, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=2, + max_overlap=0.8, + identity_only=False, + nan_policy="little", + force_fusion=False, + add_discarded=False, + window_size=0, + method="m1", + ): + self.data = data + self.metadata = self.parse_metadata(self.data) + self.max_n_individuals = max_n_individuals + self.n_multibodyparts = n_multibodyparts + self.n_uniquebodyparts = self.n_keypoints - n_multibodyparts + self.greedy = greedy + self.pcutoff = pcutoff + self.min_affinity = min_affinity + self.min_n_links = min_n_links + self.max_overlap = max_overlap + self._has_identity = "identity" in self[0] + if identity_only and not self._has_identity: + warnings.warn( + "The network was not trained with identity; setting `identity_only` to False." + ) + self.identity_only = identity_only & self._has_identity + self.nan_policy = nan_policy + self.force_fusion = force_fusion + self.add_discarded = add_discarded + self.window_size = window_size + self.method = method + self.graph = graph or self.metadata["paf_graph"] + self.paf_inds = paf_inds or self.metadata["paf"] + self._gamma = 0.01 + self._trees = dict() + self.safe_edge = False + self._kde = None + self.assemblies = dict() + self.unique = dict() + + def __getitem__(self, item): + return self.data[self.metadata["imnames"][item]] + + @classmethod + def empty( + cls, + max_n_individuals, + n_multibodyparts, + n_uniquebodyparts, + graph, + paf_inds, + greedy=False, + pcutoff=0.1, + min_affinity=0.05, + min_n_links=2, + max_overlap=0.8, + identity_only=False, + nan_policy="little", + force_fusion=False, + add_discarded=False, + window_size=0, + method="m1", + ): + # Dummy data + n_bodyparts = n_multibodyparts + n_uniquebodyparts + data = { + "metadata": { + "all_joints_names": ["" for _ in range(n_bodyparts)], + "PAFgraph": graph, + "PAFinds": paf_inds, + }, + "0": {}, + } + return cls( + data, + max_n_individuals=max_n_individuals, + n_multibodyparts=n_multibodyparts, + graph=graph, + paf_inds=paf_inds, + greedy=greedy, + pcutoff=pcutoff, + min_affinity=min_affinity, + min_n_links=min_n_links, + max_overlap=max_overlap, + identity_only=identity_only, + nan_policy=nan_policy, + force_fusion=force_fusion, + add_discarded=add_discarded, + window_size=window_size, + method=method, + ) + + @property + def n_keypoints(self): + return self.metadata["num_joints"] + + def calibrate(self, train_data_file): + df = pd.read_hdf(train_data_file) + try: + df.drop("single", level="individuals", axis=1, inplace=True) + except KeyError: + pass + n_bpts = len(df.columns.get_level_values("bodyparts").unique()) + if n_bpts == 1: + warnings.warn("There is only one keypoint; skipping calibration...") + return + + xy = df.to_numpy().reshape((-1, n_bpts, 2)) + frac_valid = np.mean(~np.isnan(xy), axis=(1, 2)) + # Only keeps skeletons that are more than 90% complete + xy = xy[frac_valid >= 0.9] + if not xy.size: + warnings.warn("No complete poses were found. Skipping calibration...") + return + + # TODO Normalize dists by longest length? + # TODO Smarter imputation technique (Bayesian? Grassmann averages?) + dists = np.vstack([pdist(data, "sqeuclidean") for data in xy]) + mu = np.nanmean(dists, axis=0) + missing = np.isnan(dists) + dists = np.where(missing, mu, dists) + try: + kde = gaussian_kde(dists.T) + kde.mean = mu + self._kde = kde + self.safe_edge = True + except np.linalg.LinAlgError: + # Covariance matrix estimation fails due to numerical singularities + warnings.warn( + "The assembler could not be robustly calibrated. Continuing without it..." + ) + + def calc_assembly_mahalanobis_dist( + self, assembly, return_proba=False, nan_policy="little" + ): + if self._kde is None: + raise ValueError("Assembler should be calibrated first with training data.") + + dists = assembly.calc_pairwise_distances() - self._kde.mean + mask = np.isnan(dists) + # Distance is undefined if the assembly is empty + if not len(assembly) or mask.all(): + if return_proba: + return np.inf, 0 + return np.inf + + if nan_policy == "little": + inds = np.flatnonzero(~mask) + dists = dists[inds] + inv_cov = self._kde.inv_cov[np.ix_(inds, inds)] + # Correct distance to account for missing observations + factor = self._kde.d / len(inds) + else: + # Alternatively, reduce contribution of missing values to the Mahalanobis + # distance to zero by substituting the corresponding means. + dists[mask] = 0 + mask.fill(False) + inv_cov = self._kde.inv_cov + factor = 1 + dot = dists @ inv_cov + mahal = factor * sqrt(np.sum((dot * dists), axis=-1)) + if return_proba: + proba = 1 - chi2.cdf(mahal, np.sum(~mask)) + return mahal, proba + return mahal + + def calc_link_probability(self, link): + if self._kde is None: + raise ValueError("Assembler should be calibrated first with training data.") + + i = link.j1.label + j = link.j2.label + ind = _conv_square_to_condensed_indices(i, j, self.n_multibodyparts) + mu = self._kde.mean[ind] + sigma = self._kde.covariance[ind, ind] + z = (link.length**2 - mu) / sigma + return 2 * (1 - 0.5 * (1 + erf(abs(z) / sqrt(2)))) + + @staticmethod + def _flatten_detections(data_dict): + ind = 0 + coordinates = data_dict["coordinates"][0] + confidence = data_dict["confidence"] + ids = data_dict.get("identity", None) + if ids is None: + ids = [np.ones(len(arr), dtype=int) * -1 for arr in confidence] + else: + ids = [arr.argmax(axis=1) for arr in ids] + for i, (coords, conf, id_) in enumerate(zip(coordinates, confidence, ids)): + if not np.any(coords): + continue + for xy, p, g in zip(coords, conf, id_): + joint = Joint(tuple(xy), p.item(), i, ind, g) + ind += 1 + yield joint + + def extract_best_links(self, joints_dict, costs, trees=None): + links = [] + for ind in self.paf_inds: + s, t = self.graph[ind] + dets_s = joints_dict.get(s, None) + dets_t = joints_dict.get(t, None) + if dets_s is None or dets_t is None: + continue + if ind not in costs: + continue + lengths = costs[ind]["distance"] + if np.isinf(lengths).all(): + continue + aff = costs[ind][self.method].copy() + aff[np.isnan(aff)] = 0 + + if trees: + vecs = np.vstack( + [[*det_s.pos, *det_t.pos] for det_s in dets_s for det_t in dets_t] + ) + dists = [] + for n, tree in enumerate(trees, start=1): + d, _ = tree.query(vecs) + dists.append(np.exp(-self._gamma * n * d)) + w = np.mean(dists, axis=0) + aff *= w.reshape(aff.shape) + + if self.greedy: + conf = np.asarray( + [ + [det_s.confidence * det_t.confidence for det_t in dets_t] + for det_s in dets_s + ] + ) + rows, cols = np.where( + (conf >= self.pcutoff * self.pcutoff) & (aff >= self.min_affinity) + ) + candidates = sorted( + zip(rows, cols, aff[rows, cols], lengths[rows, cols]), + key=lambda x: x[2], + reverse=True, + ) + i_seen = set() + j_seen = set() + for i, j, w, l in candidates: + if i not in i_seen and j not in j_seen: + i_seen.add(i) + j_seen.add(j) + links.append(Link(dets_s[i], dets_t[j], w)) + if len(i_seen) == self.max_n_individuals: + break + else: # Optimal keypoint pairing + inds_s = sorted( + range(len(dets_s)), key=lambda x: dets_s[x].confidence, reverse=True + )[: self.max_n_individuals] + inds_t = sorted( + range(len(dets_t)), key=lambda x: dets_t[x].confidence, reverse=True + )[: self.max_n_individuals] + keep_s = [ + ind for ind in inds_s if dets_s[ind].confidence >= self.pcutoff + ] + keep_t = [ + ind for ind in inds_t if dets_t[ind].confidence >= self.pcutoff + ] + aff = aff[np.ix_(keep_s, keep_t)] + rows, cols = linear_sum_assignment(aff, maximize=True) + for row, col in zip(rows, cols): + w = aff[row, col] + if w >= self.min_affinity: + links.append(Link(dets_s[keep_s[row]], dets_t[keep_t[col]], w)) + return links + + def _fill_assembly(self, assembly, lookup, assembled, safe_edge, nan_policy): + stack = [] + visited = set() + tabu = [] + counter = itertools.count() + + def push_to_stack(i): + for j, link in lookup[i].items(): + if j in assembly._idx: + continue + if link.idx in visited: + continue + heapq.heappush(stack, (-link.affinity, next(counter), link)) + visited.add(link.idx) + + for idx in assembly._idx: + push_to_stack(idx) + + while stack and len(assembly) < self.n_multibodyparts: + _, _, best = heapq.heappop(stack) + i, j = best.idx + if i in assembly._idx: + new_ind = j + elif j in assembly._idx: + new_ind = i + else: + continue + if new_ind in assembled: + continue + if safe_edge: + d_old = self.calc_assembly_mahalanobis_dist( + assembly, nan_policy=nan_policy + ) + success = assembly.add_link(best, store_dict=True) + if not success: + assembly._dict = dict() + continue + d = self.calc_assembly_mahalanobis_dist(assembly, nan_policy=nan_policy) + if d < d_old: + push_to_stack(new_ind) + try: + _, _, link = heapq.heappop(tabu) + heapq.heappush(stack, (-link.affinity, next(counter), link)) + except IndexError: + pass + else: + heapq.heappush(tabu, (d - d_old, next(counter), best)) + assembly.__dict__.update(assembly._dict) + assembly._dict = dict() + else: + assembly.add_link(best) + push_to_stack(new_ind) + + def build_assemblies(self, links): + lookup = defaultdict(dict) + for link in links: + i, j = link.idx + lookup[i][j] = link + lookup[j][i] = link + + assemblies = [] + assembled = set() + + # Fill the subsets with unambiguous, complete individuals + G = nx.Graph([link.idx for link in links]) + for chain in nx.connected_components(G): + if len(chain) == self.n_multibodyparts: + edges = [tuple(sorted(edge)) for edge in G.edges(chain)] + assembly = Assembly(self.n_multibodyparts) + for link in links: + i, j = link.idx + if (i, j) in edges: + success = assembly.add_link(link) + if success: + lookup[i].pop(j) + lookup[j].pop(i) + assembled.update(assembly._idx) + assemblies.append(assembly) + + if len(assemblies) == self.max_n_individuals: + return assemblies, assembled + + for link in sorted(links, key=lambda x: x.affinity, reverse=True): + if any(i in assembled for i in link.idx): + continue + assembly = Assembly(self.n_multibodyparts) + assembly.add_link(link) + self._fill_assembly( + assembly, lookup, assembled, self.safe_edge, self.nan_policy + ) + for link in assembly._links: + i, j = link.idx + lookup[i].pop(j) + lookup[j].pop(i) + assembled.update(assembly._idx) + assemblies.append(assembly) + + # Fuse superfluous assemblies + n_extra = len(assemblies) - self.max_n_individuals + if n_extra > 0: + if self.safe_edge: + ds_old = [ + self.calc_assembly_mahalanobis_dist(assembly) + for assembly in assemblies + ] + while len(assemblies) > self.max_n_individuals: + ds = [] + for i, j in itertools.combinations(range(len(assemblies)), 2): + if assemblies[j] not in assemblies[i]: + temp = assemblies[i] + assemblies[j] + d = self.calc_assembly_mahalanobis_dist(temp) + delta = d - max(ds_old[i], ds_old[j]) + ds.append((i, j, delta, d, temp)) + if not ds: + break + min_ = sorted(ds, key=lambda x: x[2]) + i, j, delta, d, new = min_[0] + if delta < 0 or len(min_) == 1: + assemblies[i] = new + assemblies.pop(j) + ds_old[i] = d + ds_old.pop(j) + else: + break + elif self.force_fusion: + assemblies = sorted(assemblies, key=len) + for nrow in range(n_extra): + assembly = assemblies[nrow] + candidates = [a for a in assemblies[nrow:] if assembly not in a] + if not candidates: + continue + if len(candidates) == 1: + candidate = candidates[0] + else: + dists = [] + for cand in candidates: + d = cdist(assembly.xy, cand.xy) + dists.append(np.nanmin(d)) + candidate = candidates[np.argmin(dists)] + ind = assemblies.index(candidate) + assemblies[ind] += assembly + else: + store = dict() + for assembly in assemblies: + if len(assembly) != self.n_multibodyparts: + for i in assembly._idx: + store[i] = assembly + used = [link for assembly in assemblies for link in assembly._links] + unconnected = [link for link in links if link not in used] + for link in unconnected: + i, j = link.idx + try: + if store[j] not in store[i]: + temp = store[i] + store[j] + store[i].__dict__.update(temp.__dict__) + assemblies.remove(store[j]) + for idx in store[j]._idx: + store[idx] = store[i] + except KeyError: + pass + + # Second pass without edge safety + for assembly in assemblies: + if len(assembly) != self.n_multibodyparts: + self._fill_assembly(assembly, lookup, assembled, False, "") + assembled.update(assembly._idx) + + return assemblies, assembled + + def _assemble(self, data_dict, ind_frame): + joints = list(self._flatten_detections(data_dict)) + if not joints: + return None, None + + bag = defaultdict(list) + for joint in joints: + bag[joint.label].append(joint) + + assembled = set() + + if self.n_uniquebodyparts: + unique = np.full((self.n_uniquebodyparts, 3), np.nan) + for n, ind in enumerate(range(self.n_multibodyparts, self.n_keypoints)): + dets = bag[ind] + if not dets: + continue + if len(dets) > 1: + det = max(dets, key=lambda x: x.confidence) + else: + det = dets[0] + # Mark the unique body parts as assembled anyway so + # they are not used later on to fill assemblies. + assembled.update(d.idx for d in dets) + if det.confidence <= self.pcutoff and not self.add_discarded: + continue + unique[n] = *det.pos, det.confidence + if np.isnan(unique).all(): + unique = None + else: + unique = None + + if not any(i in bag for i in range(self.n_multibodyparts)): + return None, unique + + if self.n_multibodyparts == 1: + assemblies = [] + for joint in bag[0]: + if joint.confidence >= self.pcutoff: + ass = Assembly(self.n_multibodyparts) + ass.add_joint(joint) + assemblies.append(ass) + return assemblies, unique + + if self.max_n_individuals == 1: + get_attr = operator.attrgetter("confidence") + ass = Assembly(self.n_multibodyparts) + for ind in range(self.n_multibodyparts): + joints = bag[ind] + if not joints: + continue + ass.add_joint(max(joints, key=get_attr)) + return [ass], unique + + if self.identity_only: + assemblies = [] + get_attr = operator.attrgetter("group") + temp = sorted( + (joint for joint in joints if np.isfinite(joint.confidence)), + key=get_attr, + ) + groups = itertools.groupby(temp, get_attr) + for _, group in groups: + ass = Assembly(self.n_multibodyparts) + for joint in sorted(group, key=lambda x: x.confidence, reverse=True): + if ( + joint.confidence >= self.pcutoff + and joint.label < self.n_multibodyparts + ): + ass.add_joint(joint) + if len(ass): + assemblies.append(ass) + assembled.update(ass._idx) + else: + trees = [] + for j in range(1, self.window_size + 1): + tree = self._trees.get(ind_frame - j, None) + if tree is not None: + trees.append(tree) + + links = self.extract_best_links(bag, data_dict["costs"], trees) + if self._kde: + for link in links[::-1]: + p = max(self.calc_link_probability(link), 0.001) + link.affinity *= p + if link.affinity < self.min_affinity: + links.remove(link) + + if self.window_size >= 1 and links: + # Store selected edges for subsequent frames + vecs = np.vstack([link.to_vector() for link in links]) + self._trees[ind_frame] = cKDTree(vecs) + + assemblies, assembled_ = self.build_assemblies(links) + assembled.update(assembled_) + + # Remove invalid assemblies + discarded = set( + joint + for joint in joints + if joint.idx not in assembled and np.isfinite(joint.confidence) + ) + for assembly in assemblies[::-1]: + if 0 < assembly.n_links < self.min_n_links or not len(assembly): + for link in assembly._links: + discarded.update((link.j1, link.j2)) + assemblies.remove(assembly) + if 0 < self.max_overlap < 1: # Non-maximum pose suppression + if self._kde is not None: + scores = [ + -self.calc_assembly_mahalanobis_dist(ass) for ass in assemblies + ] + else: + scores = [ass._affinity for ass in assemblies] + lst = list(zip(scores, assemblies)) + assemblies = [] + while lst: + temp = max(lst, key=lambda x: x[0]) + lst.remove(temp) + assemblies.append(temp[1]) + for pair in lst[::-1]: + if temp[1].intersection_with(pair[1]) >= self.max_overlap: + lst.remove(pair) + if len(assemblies) > self.max_n_individuals: + assemblies = sorted(assemblies, key=len, reverse=True) + for assembly in assemblies[self.max_n_individuals :]: + for link in assembly._links: + discarded.update((link.j1, link.j2)) + assemblies = assemblies[: self.max_n_individuals] + + if self.add_discarded and discarded: + # Fill assemblies with unconnected body parts + for joint in sorted(discarded, key=lambda x: x.confidence, reverse=True): + if self.safe_edge: + for assembly in assemblies: + if joint.label in assembly._visible: + continue + d_old = self.calc_assembly_mahalanobis_dist(assembly) + assembly.add_joint(joint) + d = self.calc_assembly_mahalanobis_dist(assembly) + if d < d_old: + break + assembly.remove_joint(joint) + else: + dists = [] + for i, assembly in enumerate(assemblies): + if joint.label in assembly._visible: + continue + d = cdist(assembly.xy, np.atleast_2d(joint.pos)) + dists.append((i, np.nanmin(d))) + if not dists: + continue + min_ = sorted(dists, key=lambda x: x[1]) + ind, _ = min_[0] + assemblies[ind].add_joint(joint) + + return assemblies, unique + + def assemble(self, chunk_size=1, n_processes=None): + self.assemblies = dict() + self.unique = dict() + # Spawning (rather than forking) multiple processes does not + # work nicely with the GUI or interactive sessions. + # In that case, we fall back to the serial assembly. + if chunk_size == 0 or multiprocessing.get_start_method() == "spawn": + + for i, data_dict in enumerate(tqdm(self)): + assemblies, unique = self._assemble(data_dict, i) + if assemblies: + self.assemblies[i] = assemblies + if unique is not None: + self.unique[i] = unique + else: + global wrapped # Hack to make the function pickable + + def wrapped(i): + return i, self._assemble(self[i], i) + + n_frames = len(self.metadata["imnames"]) + with multiprocessing.Pool(n_processes) as p: + with tqdm(total=n_frames) as pbar: + for i, (assemblies, unique) in p.imap_unordered( + wrapped, range(n_frames), chunksize=chunk_size + ): + if assemblies: + self.assemblies[i] = assemblies + if unique is not None: + self.unique[i] = unique + pbar.update() + + def from_pickle(self, pickle_path): + with open(pickle_path, "rb") as file: + data = pickle.load(file) + self.unique = data.pop("single", {}) + self.assemblies = data + + @staticmethod + def parse_metadata(data): + params = dict() + params["joint_names"] = data["metadata"]["all_joints_names"] + params["num_joints"] = len(params["joint_names"]) + params["paf_graph"] = data["metadata"]["PAFgraph"] + params["paf"] = data["metadata"].get( + "PAFinds", np.arange(len(params["joint_names"])) + ) + params["bpts"] = params["ibpts"] = range(params["num_joints"]) + params["imnames"] = [fn for fn in list(data) if fn != "metadata"] + return params + + def to_h5(self, output_name): + data = np.full( + ( + len(self.metadata["imnames"]), + self.max_n_individuals, + self.n_multibodyparts, + 4, + ), + fill_value=np.nan, + ) + for ind, assemblies in self.assemblies.items(): + for n, assembly in enumerate(assemblies): + data[ind, n] = assembly.data + index = pd.MultiIndex.from_product( + [ + ["scorer"], + map(str, range(self.max_n_individuals)), + map(str, range(self.n_multibodyparts)), + ["x", "y", "likelihood"], + ], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + temp = data[..., :3].reshape((data.shape[0], -1)) + df = pd.DataFrame(temp, columns=index) + df.to_hdf(output_name, key="ass") + + def to_pickle(self, output_name): + data = dict() + for ind, assemblies in self.assemblies.items(): + data[ind] = [ass.data for ass in assemblies] + if self.unique: + data["single"] = self.unique + with open(output_name, "wb") as file: + pickle.dump(data, file, pickle.HIGHEST_PROTOCOL) + + +@dataclass +class MatchedPrediction: + """A match between a prediction and a ground truth assembly + + The ground truth assembly should be None f the prediction was not matched to any GT, + and the OKS should be 0. + + Attributes: + prediction: A prediction made by a pose model. + score: The confidence score for the prediction. + ground_truth: If None, then this prediction is not matched to any ground truth + (this can happen when there are more predicted individuals than GT). + Otherwise, the ground truth assembly to which this prediction is matched. + oks: The OKS score between the prediction and the ground truth pose. + """ + + prediction: Assembly + score: float + ground_truth: Assembly | None + oks: float + + +def calc_object_keypoint_similarity( + xy_pred, + xy_true, + sigma, + margin=0, + symmetric_kpts=None, +): + visible_gt = ~np.isnan(xy_true).all(axis=1) + if visible_gt.sum() < 2: # At least 2 points needed to calculate scale + return np.nan + + true = xy_true[visible_gt] + scale_squared = np.prod(np.ptp(true, axis=0) + np.spacing(1) + margin * 2) + if np.isclose(scale_squared, 0): + return np.nan + + k_squared = (2 * sigma) ** 2 + denom = 2 * scale_squared * k_squared + if symmetric_kpts is None: + pred = xy_pred[visible_gt] + pred[np.isnan(pred)] = np.inf + dist_squared = np.sum((pred - true) ** 2, axis=1) + oks = np.exp(-dist_squared / denom) + return np.mean(oks) + else: + oks = [] + xy_preds = [xy_pred] + combos = ( + pair + for l in range(len(symmetric_kpts)) + for pair in itertools.combinations(symmetric_kpts, l + 1) + ) + for pairs in combos: + # Swap corresponding keypoints + tmp = xy_pred.copy() + for pair in pairs: + tmp[pair, :] = tmp[pair[::-1], :] + xy_preds.append(tmp) + for xy_pred in xy_preds: + pred = xy_pred[visible_gt] + pred[np.isnan(pred)] = np.inf + dist_squared = np.sum((pred - true) ** 2, axis=1) + oks.append(np.mean(np.exp(-dist_squared / denom))) + return max(oks) + + +def match_assemblies( + predictions: list[Assembly], + ground_truth: list[Assembly], + sigma: float, + margin: int = 0, + symmetric_kpts: list[tuple[int, int]] | None = None, + greedy_matching: bool = False, + greedy_oks_threshold: float = 0.0, +) -> tuple[int, list[MatchedPrediction]]: + """Matches assemblies to ground truth predictions + + Returns: + int: the total number of valid ground truth assemblies + list[MatchedPrediction]: a list containing all valid predictions, potentially + matched to ground truth assemblies. + """ + # Only consider assemblies of at least two keypoints + predictions = [a for a in predictions if len(a) > 1] + ground_truth = [a for a in ground_truth if len(a) > 1] + num_ground_truth = len(ground_truth) + + # Sort predictions by score + inds_pred = np.argsort( + [ins.affinity if ins.n_links else ins.confidence for ins in predictions] + )[::-1] + predictions = np.asarray(predictions)[inds_pred] + + # indices of unmatched ground truth assemblies + matched = [ + MatchedPrediction( + prediction=p, + score=(p.affinity if p.n_links else p.confidence), + ground_truth=None, + oks=0.0, + ) + for p in predictions + ] + + # Greedy assembly matching like in pycocotools + if greedy_matching: + matched_gt_indices = set() + for idx, pred in enumerate(predictions): + oks = [ + calc_object_keypoint_similarity( + pred.xy, + gt.xy, + sigma, + margin, + symmetric_kpts, + ) + for gt in ground_truth + ] + if np.all(np.isnan(oks)): + continue + + ind_best = np.nanargmax(oks) + + # if this gt already matched, and not a crowd, continue + if ind_best in matched_gt_indices: + continue + + # Only match the pred to the GT if the OKS value is above a given threshold + if oks[ind_best] < greedy_oks_threshold: + continue + + matched_gt_indices.add(ind_best) + matched[idx].ground_truth = ground_truth[ind_best] + matched[idx].oks = oks[ind_best] + + # Global rather than greedy assembly matching + else: + inds_true = list(range(len(ground_truth))) + mat = np.zeros((len(predictions), len(ground_truth))) + for i, a_pred in enumerate(predictions): + for j, a_true in enumerate(ground_truth): + oks = calc_object_keypoint_similarity( + a_pred.xy, + a_true.xy, + sigma, + margin, + symmetric_kpts, + ) + if ~np.isnan(oks): + mat[i, j] = oks + rows, cols = linear_sum_assignment(mat, maximize=True) + for row, col in zip(rows, cols): + matched[row].ground_truth = ground_truth[col] + matched[row].oks = mat[row, col] + _ = inds_true.remove(col) + + return num_ground_truth, matched + + +def parse_ground_truth_data_file(h5_file): + df = pd.read_hdf(h5_file) + try: + df.drop("single", axis=1, level="individuals", inplace=True) + except KeyError: + pass + # Cast columns of dtype 'object' to float to avoid TypeError + # further down in _parse_ground_truth_data. + cols = df.select_dtypes(include="object").columns + if cols.to_list(): + df[cols] = df[cols].astype("float") + n_individuals = len(df.columns.get_level_values("individuals").unique()) + n_bodyparts = len(df.columns.get_level_values("bodyparts").unique()) + data = df.to_numpy().reshape((df.shape[0], n_individuals, n_bodyparts, -1)) + return _parse_ground_truth_data(data) + + +def _parse_ground_truth_data(data): + gt = dict() + for i, arr in enumerate(data): + temp = [] + for row in arr: + if np.isnan(row[:, :2]).all(): + continue + ass = Assembly.from_array(row) + temp.append(ass) + if not temp: + continue + gt[i] = temp + return gt + + +def find_outlier_assemblies(dict_of_assemblies, criterion="area", qs=(5, 95)): + if not hasattr(Assembly, criterion): + raise ValueError(f"Invalid criterion {criterion}.") + + if len(qs) != 2: + raise ValueError( + "Two percentiles (for lower and upper bounds) should be given." + ) + + tuples = [] + for frame_ind, assemblies in dict_of_assemblies.items(): + for assembly in assemblies: + tuples.append((frame_ind, getattr(assembly, criterion))) + frame_inds, vals = zip(*tuples) + vals = np.asarray(vals) + lo, up = np.percentile(vals, qs, interpolation="nearest") + inds = np.flatnonzero((vals < lo) | (vals > up)).tolist() + return list(set(frame_inds[i] for i in inds)) + + +def _compute_precision_and_recall( + num_gt_assemblies: int, + oks_values: np.ndarray, + oks_threshold: float, + recall_thresholds: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Computes the precision and recall scores at a given OKS threshold + + Args: + num_gt_assemblies: the number of ground truth assemblies (used to compute false + negatives + true positives). + oks_values: the OKS value to the matched GT assembly for each prediction + oks_threshold: the OKS threshold at which recall and precision are being + computed + recall_thresholds: the recall thresholds to use to compute scores + + Returns: + The precision and recall arrays at each recall threshold + """ + tp = np.cumsum(oks_values >= oks_threshold) + fp = np.cumsum(oks_values < oks_threshold) + rc = tp / num_gt_assemblies + pr = tp / (fp + tp + np.spacing(1)) + recall = rc[-1] + + # Guarantee precision decreases monotonically, see + # https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173 + for i in range(len(pr) - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + inds_rc = np.searchsorted(rc, recall_thresholds, side="left") + precision = np.zeros(inds_rc.shape) + valid = inds_rc < len(pr) + precision[valid] = pr[inds_rc[valid]] + return precision, recall + + +def evaluate_assembly_greedy( + assemblies_gt: dict[Any, list[Assembly]], + assemblies_pred: dict[Any, list[Assembly]], + oks_sigma: float, + oks_thresholds: Iterable[float], + margin: int | float = 0, + symmetric_kpts: list[tuple[int, int]] | None = None, +) -> dict: + """Runs greedy mAP evaluation, as done by pycocotools + + Args: + assemblies_gt: A dictionary mapping image ID (e.g. filepath) to ground truth + assemblies. Should contain all the same keys as ``assemblies_pred``. + assemblies_pred: A dictionary mapping image ID (e.g. filepath) to predicted + assemblies. Should contain all the same keys as ``assemblies_gt``. + oks_sigma: The sigma to use to compute OKS values for keypoints . + oks_thresholds: The OKS thresholds at which to compute precision & recall. + margin: The margin to use to compute bounding boxes from keypoints. + symmetric_kpts: The symmetric keypoints in the dataset. + """ + recall_thresholds = np.linspace( # np.linspace(0, 1, 101) + start=0.0, stop=1.00, num=int(np.round((1.00 - 0.0) / 0.01)) + 1, endpoint=True + ) + precisions = [] + recalls = [] + for oks_t in oks_thresholds: + all_matched = [] + total_gt_assemblies = 0 + for ind, gt_assembly in assemblies_gt.items(): + pred_assemblies = assemblies_pred.get(ind, []) + num_gt_assemblies, matched = match_assemblies( + pred_assemblies, + gt_assembly, + oks_sigma, + margin, + symmetric_kpts, + greedy_matching=True, + greedy_oks_threshold=oks_t, + ) + all_matched.extend(matched) + total_gt_assemblies += num_gt_assemblies + + if len(all_matched) == 0: + precisions.append(0.0) + recalls.append(0.0) + continue + + # Global sort of assemblies (across all images) by score + scores = np.asarray([-m.score for m in all_matched]) + sorted_pred_indices = np.argsort(scores, kind="mergesort") + oks = np.asarray([match.oks for match in all_matched])[sorted_pred_indices] + + # Compute prediction and recall + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, oks_t, recall_thresholds + ) + precisions.append(p) + recalls.append(r) + + precisions = np.asarray(precisions) + recalls = np.asarray(recalls) + return { + "precisions": precisions, + "recalls": recalls, + "mAP": precisions.mean(), + "mAR": recalls.mean(), + } + + +def evaluate_assembly( + ass_pred_dict, + ass_true_dict, + oks_sigma=0.072, + oks_thresholds=np.linspace(0.5, 0.95, 10), + margin=0, + symmetric_kpts=None, + greedy_matching=False, + with_tqdm: bool = True, +): + if greedy_matching: + return evaluate_assembly_greedy( + ass_true_dict, + ass_pred_dict, + oks_sigma=oks_sigma, + oks_thresholds=oks_thresholds, + margin=margin, + symmetric_kpts=symmetric_kpts, + ) + + # sigma is taken as the median of all COCO keypoint standard deviations + all_matched = [] + total_gt_assemblies = 0 + + gt_assemblies = ass_true_dict.items() + if with_tqdm: + gt_assemblies = tqdm(gt_assemblies) + + for ind, gt_assembly in gt_assemblies: + pred_assemblies = ass_pred_dict.get(ind, []) + num_gt, matched = match_assemblies( + pred_assemblies, + gt_assembly, + oks_sigma, + margin, + symmetric_kpts, + greedy_matching, + ) + all_matched.extend(matched) + total_gt_assemblies += num_gt + + if not all_matched: + return { + "precisions": np.array([]), + "recalls": np.array([]), + "mAP": 0.0, + "mAR": 0.0, + } + + conf_pred = np.asarray([match.score for match in all_matched]) + idx = np.argsort(-conf_pred, kind="mergesort") + # Sort matching score (OKS) in descending order of assembly affinity + oks = np.asarray([match.oks for match in all_matched])[idx] + recall_thresholds = np.linspace(0, 1, 101) + precisions = [] + recalls = [] + for t in oks_thresholds: + p, r = _compute_precision_and_recall( + total_gt_assemblies, oks, t, recall_thresholds + ) + precisions.append(p) + recalls.append(r) + + precisions = np.asarray(precisions) + recalls = np.asarray(recalls) + return { + "precisions": precisions, + "recalls": recalls, + "mAP": precisions.mean(), + "mAR": recalls.mean(), + } diff --git a/deeplabcut/core/metrics/__init__.py b/deeplabcut/core/metrics/__init__.py new file mode 100644 index 0000000000..94397de57a --- /dev/null +++ b/deeplabcut/core/metrics/__init__.py @@ -0,0 +1,13 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +from .api import compute_metrics, prepare_evaluation_data +from .bbox import compute_bbox_metrics +from .identity import compute_identity_scores diff --git a/deeplabcut/core/metrics/api.py b/deeplabcut/core/metrics/api.py new file mode 100644 index 0000000000..75d4e7bbcb --- /dev/null +++ b/deeplabcut/core/metrics/api.py @@ -0,0 +1,176 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""API methods to get metrics for deep learning models""" +from __future__ import annotations + +import numpy as np + +import deeplabcut.core.metrics.distance_metrics as distance_metrics + + +def compute_metrics( + ground_truth: dict[str, np.ndarray], + predictions: dict[str, np.ndarray], + single_animal: bool = False, + unique_bodypart_gt: dict[str, np.ndarray] | None = None, + unique_bodypart_poses: dict[str, np.ndarray] | None = None, + pcutoff: float = -1, + oks_bbox_margin: int = 0, + oks_sigma: float = 0.1, + per_keypoint_rmse: bool = False, + compute_detection_rmse: bool = True, +) -> dict: + """Computes pose estimation performance metrics + + Given ground truth pose labels and predictions on a dataset, computes RMSE and pose + mAP/mAR using OKS. + + The image paths in the ground_truth dict must be the same as the ones in the + predictions dict. + + Single animal RMSE is computed by simply calculating the Euclidean distance between + each ground truth keypoint and the corresponding prediction. + + Multi-animal RMSE is computed differently: predictions are first matched to ground + truth individuals using greedy OKS matching. OKS (or object keypoint similarity) is + a similarity metric for keypoints (you can read more about it and its definition + here: https://cocodataset.org/#keypoints-eval). RMSE is then computed only between + predictions and the ground truth pose they are matched to, only when the OKS is + greater than a small threshold. Predictions that cannot be matched to any ground + truth with non-zero OKS are not used to compute RMSE. + + Args: + ground_truth: The ground truth pose for which to compute metrics in the dataset. + This should be a dictionary mapping strings (image UIDs, such as image + paths) to ground truth pose for the image. The pose arrays should be + in the format (num_individuals, num_bodyparts, 3), where the 3 values are + x, y and visibility. The ``num_individuals`` corresponds to the number of + individuals labeled in each image. + predictions: The predicted poses for which to compute metrics in the dataset. + This should be a dictionary mapping strings (image UIDs, such as image + paths) to pose predictions for the image. The pose arrays should be + in the format (num_predictions, num_bodyparts, 3), where the 3 values are + x, y and score. The number of predictions can be different to the number of + ground truth individuals labeled for an image. + single_animal: Whether the metrics are being computed on a single-animal or + multi-animal dataset. This has an impact on RMSE computation. + unique_bodypart_gt: If unique bodyparts are defined for the dataset, they should + be contained in this dict in the same format as the ``ground_truth`` dict. + unique_bodypart_poses: If unique bodyparts are defined for the dataset, the + predictions should be contained in this dict in the same format as the + ``predictions`` dict. + pcutoff: The threshold to compute the "rmse_cutoff" score (RMSE of all + predictions with score above the cutoff). + oks_bbox_margin: The margin to add around keypoints to compute the area for OKS + computation. + oks_sigma: The OKS sigma to use to compute pose. + per_keypoint_rmse: Compute per-keypoint RMSE values. + compute_detection_rmse: Computes detection RMSE (without animal assembly) if the + predictions are from a multi-animal model. + + Returns: + A dictionary containing keys "rmse", "rmse_cutoff", "mAP" and "mAR" mapping + to those metrics on the given dataset. + + If unique bodyparts are given, two extra keys "rmse_unique_bodyparts" and + "rmse_pcutoff_unique_bodyparts" are also returned, containing the metrics for + the unique bodyparts head. + + If `per_keypoint_evaluation=True`, "keypoint_rmse", "keypoint_rmse_cutoff" (and + optionally "unique_keypoint_rmse" and "unique_keypoint_rmse_cutoff") keys are + added, containing a list of floats representing the RMSE for each keypoint. + + Examples: + >>> # Define the p-cutoff, prediction, and target DataFrames + >>> pcutoff = 0.5 + >>> ground_truth = {"img0": np.array([[[1.0, 1.0, 2.0], ...], ...]), ...} + >>> predictions = {"img0": np.array([[[2.0, 1.0, 0.4], ...], ...]), ...} + >>> scores = compute_metrics(ground_truth, predictions, pcutoff=pcutoff) + >>> print(scores) + { + "rmse": 1.0, + "rmse_pcutoff": 0.0, + 'mAP': 84.2, + 'mAR': 74.5 + } # Sample output scores + """ + data = prepare_evaluation_data(ground_truth, predictions) + oks_scores = distance_metrics.compute_oks( + data=data, + oks_sigma=oks_sigma, + oks_bbox_margin=oks_bbox_margin, + ) + + data_unique = None + if unique_bodypart_gt is not None: + assert unique_bodypart_poses is not None + data_unique = prepare_evaluation_data(unique_bodypart_gt, unique_bodypart_poses) + + rmse_scores = distance_metrics.compute_rmse( + data, + single_animal, + pcutoff, + data_unique=data_unique, + per_keypoint_results=per_keypoint_rmse, + ) + results = dict(**rmse_scores, **oks_scores) + + if compute_detection_rmse and not single_animal: + det_rmse, det_rmse_p = distance_metrics.compute_detection_rmse( + data, pcutoff, data_unique=data_unique, + ) + results["rmse_detections"] = det_rmse + results["rmse_detections_pcutoff"] = det_rmse_p + + return results + + +def prepare_evaluation_data( + ground_truth: dict[str, np.ndarray], + predictions: dict[str, np.ndarray], +) -> list[tuple[np.ndarray, np.ndarray]]: + """Prepares predictions and ground truth pose to compute metrics. + + Only keeps ground truth and predicted assemblies with at least 2 valid keypoints. + Sets the coordinates for all keypoints that aren't visible (for ground truth, + visibility <= 0 and for predictions score <= 0) to ``np.nan``. + + Sorts valid predictions by score. + + Args: + ground_truth: For each image, the GT of shape (n_idv, n_bpt, 3). + predictions: For each image, the pose predictions of shape (n_pred, n_bpt, 3). + + Returns: + A list containing (ground truth pose, predicted pose) for each image in the + dataset, where the predicted pose is sorted from highest to lowest score. + """ + pose_data = [] + for image, gt in ground_truth.items(): + gt = gt.copy() + gt[gt[..., 2] <= 0] = np.nan + + # only keep ground truth pose with at least one keypoint + gt_mask = np.any(np.all(~np.isnan(gt), axis=-1), axis=-1) + gt = gt[gt_mask] + + pred = predictions[image][..., :3].copy() # PAF have 5 values; keep xy + score + pred[pred[..., 2] < 0] = np.nan + + # only keep predicted pose with at least two keypoints + pred_mask = np.any(np.all(~np.isnan(pred), axis=-1), axis=-1) + pred = pred[pred_mask] + + scores = np.nanmean(pred[:, :, 2], axis=-1) + pred_order = np.argsort(-scores, kind="mergesort") + pose_data.append((gt, pred[pred_order])) + + return pose_data diff --git a/deeplabcut/core/metrics/bbox.py b/deeplabcut/core/metrics/bbox.py new file mode 100644 index 0000000000..2ee60cdc86 --- /dev/null +++ b/deeplabcut/core/metrics/bbox.py @@ -0,0 +1,159 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Bounding box metrics + +Metrics are currently computed using pycocotools, which can be installed with `pypi` +(see https://github.com/ppwwyyxx/cocoapi/tree/master). +""" +from __future__ import annotations + +from unittest.mock import Mock, patch + +import numpy as np + +try: + from pycocotools.coco import COCO + from pycocotools.cocoeval import COCOeval + + with_pycocotools = True +except ModuleNotFoundError as err: + with_pycocotools = False + + +@patch("pycocotools.coco.print", Mock()) +@patch("pycocotools.cocoeval.print", Mock()) +def compute_bbox_metrics( + ground_truth: dict[str, dict], + detections: dict[str, dict], +) -> dict[str, float]: + """Computes bbox mAP and mAR metrics for bounding boxes. + + Args: + ground_truth: A dictionary mapping image UIDs (such as image paths or filenames) + to a ground truth labels dict. The labels dict should contain the keys + "width" (image width), "height" (image height) and "bboxes" (a numpy array + of shape (num_gt_bboxes, 4) containing the ground truth bounding boxes in + format xywh). + detections: A dictionary mapping image UIDs (such as image paths or filenames) + to a predicted bounding box dict. The detections dict should contain the + keys "bboxes" (a numpy array of shape (num_detected_bboxes, 4) containing + the predicted bounding boxes in format xywh) and "scores" (a numpy array of + length num_detected_bboxes containing the confidence score for each + predicted bounding box). + + Returns: + The bounding box mAP/mAR metrics in a dictionary. + + Raises: + ModuleNotFoundError: if ``pycocotools`` is not installed + ValueError: if there are mismatches in the keys of ground_truth and detections + """ + if not with_pycocotools: + raise ModuleNotFoundError("pycocotools not installed! can't compute bbox mAP") + + if len(detections) != len(ground_truth): + raise ValueError() + + coco = COCO() + coco.dataset["annotations"] = [] + coco.dataset["categories"] = [{"id": 1, "name": "animals", "supercategory": "obj"}] + coco.dataset["images"] = [] + predictions = [] + for idx, (img, gt) in enumerate(ground_truth.items()): + img_id = idx + 1 + coco.dataset["images"].append( + { + "id": img_id, + "file_name": img, + "width": gt["width"], + "height": gt["height"], + } + ) + for bbox in gt["bboxes"][:, :4]: + ann_id = len(coco.dataset["annotations"]) + 1 + coco.dataset["annotations"].append( + { + "id": ann_id, + "image_id": img_id, + "category_id": 1, + "area": max(1, (bbox[2] * bbox[3]).item()), + "bbox": bbox, + "iscrowd": 0, + } + ) + + for bbox, score in zip(detections[img]["bboxes"], detections[img]["scores"]): + predictions.append(np.array([img_id, *bbox, score, 1])) + + if len(predictions) == 0: + return { + "mAP@50:95": 0.0, + "mAP@50": 0.0, + "mAP@75": 0.0, + "mAR@50:95": 0.0, + "mAR@50": 0.0, + "mAR@75": 0.0, + } + + predictions = np.stack(predictions, axis=0) + coco.createIndex() + coco_det = coco.loadRes(predictions) + coco_eval = COCOeval(coco, coco_det, iouType="bbox") + coco_eval.evaluate() + coco_eval.accumulate() + return { + name: val + for name, val in [ + _get_metric(coco_eval, recall=False), + _get_metric(coco_eval, recall=False, iou_threshold=0.5), + _get_metric(coco_eval, recall=False, iou_threshold=0.75), + _get_metric(coco_eval, recall=True), + _get_metric(coco_eval, recall=True, iou_threshold=0.5), + _get_metric(coco_eval, recall=True, iou_threshold=0.75), + ] + } + + +def _get_metric( + coco_eval: COCOeval, + recall: bool = False, + iou_threshold: float | None = None, + area_rng: str = "all", + max_dets: int = 100, +) -> tuple[str, float]: + metric_name = "mAR" if recall else "mAP" + if iou_threshold is not None: + thresh = f"{int(100 * iou_threshold)}" + else: + low, high = coco_eval.params.iouThrs[0], coco_eval.params.iouThrs[-1] + thresh = f"{int(100 * low)}:{int(100 * high)}" + + aind = [i for i, aRng in enumerate(coco_eval.params.areaRngLbl) if aRng == area_rng] + mind = [i for i, mDet in enumerate(coco_eval.params.maxDets) if mDet == max_dets] + if recall: + s = coco_eval.eval["recall"] + if iou_threshold is not None: + t = np.where(iou_threshold == coco_eval.params.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + else: + s = coco_eval.eval["precision"] + if iou_threshold is not None: + t = np.where(iou_threshold == coco_eval.params.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = 100 * np.mean(s[s > -1]).item() + + return f"{metric_name}@{thresh}", mean_s diff --git a/deeplabcut/core/metrics/distance_metrics.py b/deeplabcut/core/metrics/distance_metrics.py new file mode 100644 index 0000000000..ac2fbc04c1 --- /dev/null +++ b/deeplabcut/core/metrics/distance_metrics.py @@ -0,0 +1,459 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Implementations of methods to compute distance metrics such as RMSE or OKS""" +from __future__ import annotations + +import numpy as np + +import deeplabcut.core.metrics.matching as matching +from deeplabcut.core.crossvalutils import find_closest_neighbors +from deeplabcut.core.inferenceutils import calc_object_keypoint_similarity + + +def compute_oks_matrix( + ground_truth: np.ndarray, + predictions: np.ndarray, + oks_sigma: float, + oks_bbox_margin: float = 0.0, +) -> np.ndarray: + """Computes the OKS score for each (prediction, gt) pair in an image + + Args: + ground_truth: The GT poses for an image, shape (n_individuals, n_kpts, 2) + predictions: The predicted poses in the image, shape (n_pred, n_kpts, 2) + oks_sigma: The sigma value to use to compute OKS + oks_bbox_margin: The margin to add around keypoints when computing the area. + FIXME(niels) We should allow the use of ground truth bboxes to get area + + Returns: + A matrix of shape (n_pred, n_kpts) where entry (i, j) is the OKS between + prediction i and ground truth j. + """ + oks_matrix = np.zeros((len(predictions), len(ground_truth))) + for pred_idx, pred in enumerate(predictions): + for gt_idx, gt in enumerate(ground_truth): + oks_matrix[pred_idx, gt_idx] = calc_object_keypoint_similarity( + pred[:, :2], + gt[:, :2], + sigma=oks_sigma, + margin=oks_bbox_margin, + ) + + return oks_matrix + + +def compute_oks( + data: list[tuple[np.ndarray, np.ndarray]], + oks_bbox_margin: float = 0.0, + oks_sigma: float = 0.1, + oks_thresholds: np.ndarray | None = None, + oks_recall_thresholds: np.ndarray | None = None, +) -> dict[str, float]: + """Computes the OKS for pose at different thresholds. + + Args: + data: The data for which to compute OKS mAP: a list containing (gt_poses, + predicted_poses) tuples, where gt_pose is an array of shape + (num_gt_individuals, num_bpts, 3) and predicted_poses is an array of shape + (num_predictions, num_bpts, 3). For the GT, the 3 coordinates are (x, y, + visibility) while for the pose they are (x, y, confidence score). + oks_sigma: The OKS sigma to use to compute pose. + oks_bbox_margin: The margin to add around keypoints to compute the area for OKS + computation. + oks_thresholds: The OKS thresholds at which to compute AP. If None, defaults to + (0.5, 0.55, 0.6, ..., 0.9, 0.95). + oks_recall_thresholds: The recall thresholds to use to compute mAP. If None, + defaults to the same default values used in pycocotools. + + Returns: + A dictionary containing mAP and mAR scores. + """ + if oks_thresholds is None: + oks_thresholds = np.linspace(0.5, 0.95, 10) + + if oks_recall_thresholds is None: + oks_recall_thresholds = np.linspace( + start=0.0, + stop=1.00, + num=int(np.round((1.00 - 0.0) / 0.01)) + 1, + endpoint=True, + ) + + total_gt = 0 + pose_data = [] + for gt, pred in data: + # filter data to only keep individuals with at least 2 valid keypoints + gt = gt[np.sum(np.all(~np.isnan(gt), axis=-1), axis=-1) > 1] + pred = pred[np.sum(np.all(~np.isnan(pred), axis=-1), axis=-1) > 1] + + oks_matrix = compute_oks_matrix( + gt[:, :, :2], + pred[:, :, :2], + oks_sigma=oks_sigma, + oks_bbox_margin=oks_bbox_margin, + ) + + total_gt += len(gt) + pose_data.append((gt, pred, oks_matrix)) + + precisions, recalls = [], [] + for oks_threshold in oks_thresholds: + matches = [] + for gt, pred, oks_matrix in pose_data: + image_matches = matching.match_greedy_oks( + gt, + pred, + oks_matrix=oks_matrix, + oks_threshold=oks_threshold, + ) + matches.extend(image_matches) + + if len(matches) == 0: # no predictions -> precision 0, recall 0 + return {"mAP": 0, "mAR": 0} + + scores = np.asarray([m.score for m in matches]) + match_order = np.argsort(-scores, kind="mergesort") + oks_values = np.asarray([m.oks for m in matches]) + oks_values = oks_values[match_order] + + tp = np.cumsum(oks_values >= oks_threshold) + fp = np.cumsum(oks_values < oks_threshold) + rc = tp / total_gt + pr = tp / (fp + tp + np.spacing(1)) + recall = rc[-1] + + # Guarantee precision decreases monotonically, see + # https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173 + for i in range(len(pr) - 1, 0, -1): + if pr[i] > pr[i - 1]: + pr[i - 1] = pr[i] + + inds_rc = np.searchsorted(rc, oks_recall_thresholds, side="left") + precision = np.zeros(inds_rc.shape) + valid = inds_rc < len(pr) + precision[valid] = pr[inds_rc[valid]] + + precisions.append(precision) + recalls.append(recall) + + precisions = np.asarray(precisions) + recalls = np.asarray(recalls) + return { + "mAP": 100 * precisions.mean().item(), + "mAR": 100 * recalls.mean().item(), + } + + +def match_predictions_for_rmse( + data: list[tuple[np.ndarray, np.ndarray]], + single_animal: bool, + oks_bbox_margin: float = 0.0, +) -> list[matching.PotentialMatch]: + """Matches GT keypoints to predictions to compute RMSE. + + Single animal RMSE is computed by simply calculating the distance between each + ground truth keypoint and the corresponding prediction. + + Multi-animal RMSE is computed differently: predictions are first matched to ground + truth individuals using greedy OKS matching. RMSE is then computed only between + predictions and the ground truth pose they are matched to, only when the OKS is + non-zero (greater than a small threshold). Predictions that cannot be matched to + any ground truth with non-zero OKS are not used to compute RMSE. + + Args: + data: The data for which to compute RMSE. This is a list containing (gt_poses, + predicted_poses), where gt_pose is an array of shape (num_gt_individuals, + num_bpts, 3) and predicted_poses is an array of shape (num_predictions, + num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for + the pose they are (x, y, confidence score). + single_animal: Whether this is a single animal dataset. + oks_bbox_margin: When single_animal is False, predictions are matched to GT + using OKS. This is the margin used to apply when computing the bbox from + the pose to compute OKS. + + Returns: + A list containing the predictions matched to ground truth. + + Raises: + ValueError: If `single_animal=True` but more than one ground truth/predicted + keypoint is found for an entry + """ + matches = [] + for gt, pred in data: + if single_animal: + if gt.shape[0] > 1 or pred.shape[0] > 1: + raise ValueError( + "At most 1 individual and 1 prediction can be given when computing " + f"single animal RMSE. Found gt={gt.shape}, pred={pred.shape}" + ) + + image_matches = [] + if gt.shape[0] == 1 and pred.shape[0] == 1: + match = matching.PotentialMatch.from_pose(pred[0]) + match.match(gt[0], oks=float("nan")) # OKS not needed for RMSE + image_matches.append(match) + else: + oks_matrix = compute_oks_matrix( + gt[:, :, :2], + pred[:, :, :2], + oks_sigma=0.1, + oks_bbox_margin=oks_bbox_margin, + ) + image_matches = matching.match_greedy_oks( + gt, + pred, + oks_matrix=oks_matrix, + oks_threshold=1e-6, + ) + + matches.extend(image_matches) + + return matches + + +def compute_rmse( + data: list[tuple[np.ndarray, np.ndarray]], + single_animal: bool, + pcutoff: float | list[float], + data_unique: list[tuple[np.ndarray, np.ndarray]] | None = None, + per_keypoint_results: bool = False, + oks_bbox_margin: float = 0.0, +) -> dict[str, float]: + """Computes the RMSE for pose predictions. + + Single animal RMSE is computed by simply calculating the distance between each + ground truth keypoint and the corresponding prediction. + + Multi-animal RMSE is computed differently: predictions are first matched to ground + truth individuals using greedy OKS matching. RMSE is then computed only between + predictions and the ground truth pose they are matched to, only when the OKS is + non-zero (greater than a small threshold). Predictions that cannot be matched to + any ground truth with non-zero OKS are not used to compute RMSE. + + Args: + data: The data for which to compute RMSE. This is a list containing (gt_poses, + predicted_poses), where gt_pose is an array of shape (num_gt_individuals, + num_bpts, 3) and predicted_poses is an array of shape (num_predictions, + num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for + the pose they are (x, y, confidence score). + single_animal: Whether this is a single animal dataset. + pcutoff: The p-cutoff to use to compute RMSE. If a list, the cutoff for each + bodypart is set individually. The list must have length num_bodyparts + + num_unique_bodyparts. + data_unique: Unique bodypart ground truth and predictions to include in RMSE + computations, if there are any such bodyparts. + per_keypoint_results: Whether to compute the RMSE for each individual keypoint. + oks_bbox_margin: When single_animal is False, predictions are matched to GT + using OKS. This is the margin used to apply when computing the bbox from + the pose to compute OKS. + + Returns: + A dictionary matching metric names to values. It will at least have "rmse" and + "rmse_cutoff" keys. If `per_keypoint_results=True` and there is at least one + non-NaN pixel error it will also contain "rmse_keypoint_X" and + "rmse_cutoff_keypoint_X" keys for each bodypart, where X is the index of the + bodypart. + + Raises: + ValueError: If `single_animal=True` but more than one ground truth/predicted + keypoint is found for an entry + """ + matches = match_predictions_for_rmse(data, single_animal, oks_bbox_margin) + pixel_errors, keypoint_scores = None, None + if len(matches) > 0: + pixel_errors = np.stack([m.pixel_errors() for m in matches]) + keypoint_scores = np.stack([m.keypoint_scores() for m in matches]) + + error, support, cutoff_error, cutoff_support = 0, 0, 0, 0 + if pixel_errors is not None: + bpt_cutoffs = pcutoff + if not isinstance(pcutoff, (int, float)): + bpt_cutoffs = pcutoff[:pixel_errors.shape[1]] + + error, support, cutoff_error, cutoff_support = collect_pixel_errors( + pixel_errors, keypoint_scores, bpt_cutoffs, + ) + + unique_pixel_errors, unique_keypoint_scores = None, None + if data_unique is not None: + u_matches = match_predictions_for_rmse(data_unique, single_animal=True) + if len(u_matches) > 0: + unique_pixel_errors = np.stack([m.pixel_errors() for m in u_matches]) + unique_keypoint_scores = np.stack([m.keypoint_scores() for m in u_matches]) + + bpt_cutoffs = pcutoff + if not isinstance(pcutoff, (int, float)): + bpt_cutoffs = pcutoff[-unique_pixel_errors.shape[1]:] + u_error, u_support, u_cutoff_error, u_cutoff_support = collect_pixel_errors( + unique_pixel_errors, unique_keypoint_scores, bpt_cutoffs, + ) + error += u_error + support += u_support + cutoff_error += u_cutoff_error + cutoff_support += u_cutoff_support + + results = dict(rmse=float("nan"), rmse_pcutoff=float("nan")) + if support > 0: + results["rmse"] = float(error / support) + if cutoff_support > 0: + results["rmse_pcutoff"] = float(cutoff_error / cutoff_support) + + if per_keypoint_results: + bodypart_errors = [("rmse_keypoint", pixel_errors)] + if unique_pixel_errors is not None: + bodypart_errors.append(("rmse_unique_keypoint", unique_pixel_errors)) + + for key_prefix, bpt_errors in bodypart_errors: + for idx, keypoint_error in enumerate(bpt_errors.T): + rmse = float("nan") + if np.any(~np.isnan(keypoint_error)): + rmse = np.nanmean(keypoint_error).item() + results[f"{key_prefix}_{idx}"] = float(rmse) + + return results + + +def compute_detection_rmse( + data: list[tuple[np.ndarray, np.ndarray]], + pcutoff: float | list[float], + data_unique: list[tuple[np.ndarray, np.ndarray]] | None = None, +) -> tuple[float, float]: + """Computes the detection RMSE for pose predictions. + + The detection RMSE score does not take individual assemblies into account. It only + judges the performance of the detections, matching each predicted keypoint to the + closest ground truth for each bodypart. + + This is the same way multi-animal RMSE was computed in DeepLabCut 2.X. + + Args: + data: The data for which to compute RMSE. This is a list containing (gt_poses, + predicted_poses), where gt_pose is an array of shape (num_gt_individuals, + num_bpts, 3) and predicted_poses is an array of shape (num_predictions, + num_bpts, 3). For the GT, the 3 coordinates are (x, y, visibility) while for + the pose they are (x, y, confidence score). + pcutoff: The p-cutoff to use to compute RMSE. If a list, the cutoff for each + bodypart is set individually. The list must have length num_bodyparts + + num_unique_bodyparts. + data_unique: Unique bodypart ground truth and predictions to include in RMSE + computations, if there are any such bodyparts. + + Returns: + The detection RMSE and detection RMSE after removing all detections with a + score below the pcutoff. + """ + distances = [] + distances_cutoff = [] + for image_gt, image_pred in data: + image_gt = image_gt.transpose((1, 0, 2)) # to (num_bpts, num_gt_individuals, 3) + image_pred = image_pred.transpose((1, 0, 2)) # to (num_bpts, num_pred, 3) + + for bpt_index, (bpt_gt, bpt_pred) in enumerate(zip(image_gt, image_pred)): + # filter NaNs and invalid values + bpt_gt = bpt_gt[~np.any(np.isnan(bpt_gt), axis=1)] + bpt_pred = bpt_pred[~np.any(np.isnan(bpt_pred), axis=1)] + if len(bpt_gt) == 0 or len(bpt_pred) == 0: + continue + + if isinstance(pcutoff, (int, float)): + bpt_pcutoff = pcutoff + else: + bpt_pcutoff = pcutoff[bpt_index] + + # assignment of predicted bodyparts to ground truth + neighbors = find_closest_neighbors(bpt_gt, bpt_pred, k=3) + for gt_index, pred_index in enumerate(neighbors): + if pred_index != -1: + gt = bpt_gt[gt_index] + pred = bpt_pred[pred_index] + dist = np.linalg.norm(gt[:2] - pred[:2]) + distances.append(dist) + + score = bpt_pred[pred_index, 2] + if score >= bpt_pcutoff: + distances_cutoff.append(dist) + + if data_unique is not None: + for image_gt, image_pred in data_unique: + assert len(image_gt) <= 1 and len(image_pred) <= 1, ( + f"Unique GT an predictions must have length 0 or 1! Found {image_gt.shape}, " + f"{image_pred.shape}." + ) + + if len(image_gt) == 1 and len(image_pred) == 1: + unique_gt, unique_pred = image_gt[0], image_pred[0] + num_unique = unique_gt.shape[0] + unique_cutoffs = pcutoff + if not isinstance(pcutoff, (int, float)): + unique_cutoffs = pcutoff[-num_unique:] + + for bpt_index, (gt, pred) in enumerate(zip(unique_gt, unique_pred)): + dist = np.linalg.norm(gt[:2] - pred[:2]) + distances.append(dist) + + score = pred[2] + if isinstance(pcutoff, (int, float)): + bpt_pcutoff = unique_cutoffs + else: + bpt_pcutoff = unique_cutoffs[bpt_index] + + if score >= bpt_pcutoff: + distances_cutoff.append(dist) + + rmse, rmse_cutoff = float("nan"), float("nan") + if len(distances) == 0: + return rmse, rmse_cutoff + + distances = np.stack(distances) + if np.any(~np.isnan(distances)): + rmse = float(np.nanmean(distances).item()) + + if len(distances_cutoff) > 0: + distances_cutoff = np.stack(distances_cutoff) + if np.any(~np.isnan(distances_cutoff)): + rmse_cutoff = float(np.nanmean(distances_cutoff).item()) + + return rmse, rmse_cutoff + + +def collect_pixel_errors( + pixel_errors: np.ndarray, + keypoint_scores: np.ndarray, + pcutoff: float, +) -> tuple[float, int, float, int]: + """Collects pixel errors for RMSE computation + + Args: + pixel_errors: The pixel errors to collect, of shape (num_matches, num_bodyparts) + keypoint_scores: The scores corresponding to the pixel errors, of shape + (num_matches, num_bodyparts). + pcutoff: The pcutoff to use when computing cutoff RMSE. + + Returns: error, support, cutoff_error, support_cutoff + error: The sum of all pixel errors. + support: The number of valid pixel errors. + cutoff_error: The sum of all pixel errors with score > pcutoff. + support_cutoff: The number of valid pixel errors with score > pcutoff. + """ + error = 0.0 + cutoff_error = 0.0 + support = np.sum(~np.isnan(pixel_errors)).item() + support_cutoff = 0 + if support > 0: + error += np.nansum(pixel_errors).item() + + cutoff_mask = keypoint_scores >= pcutoff + cutoff_pixel_errors = pixel_errors[cutoff_mask] + support_cutoff = np.sum(~np.isnan(cutoff_pixel_errors)).item() + if support_cutoff > 0: + cutoff_error = np.nansum(cutoff_pixel_errors).item() + + return error, support, cutoff_error, support_cutoff diff --git a/deeplabcut/core/metrics/identity.py b/deeplabcut/core/metrics/identity.py new file mode 100644 index 0000000000..1720bfdffa --- /dev/null +++ b/deeplabcut/core/metrics/identity.py @@ -0,0 +1,92 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Implementations of methods to compute identity prediction accuracy""" +from __future__ import annotations + +import numpy as np +from sklearn.metrics import accuracy_score + +from deeplabcut.core.crossvalutils import find_closest_neighbors + + +def compute_identity_scores( + individuals: list[str], + bodyparts: list[str], + predictions: dict[str, np.ndarray], + identity_scores: dict[str, np.ndarray], + ground_truth: dict[str, np.ndarray], +) -> dict[str, float]: + """ + FIXME: With DLCRNet all heatmap "peaks" above 0.01 were kept, with 1 keypoint and + 1 identity score map per peak. Then, for each ground truth keypoint, we selected + the prediction closest to it, and evaluated the identity score in that position. + This is no longer the case, as we're now evaluating after assembly. So we only + have num_individuals assemblies. + + Args: + individuals: + bodyparts: + predictions: (num_assemblies, num_bodyparts, 3) + identity_scores: (num_assemblies, num_bodyparts, num_individuals) + ground_truth: (num_individuals, num_bodyparts, 3) + + Returns: + + """ + if not len(predictions) == len(ground_truth): + raise ValueError("Mismatch between number of predictions and ground truth") + + all_bpts = np.asarray(len(individuals) * bodyparts) + ids = np.full((len(predictions), len(all_bpts), 2), np.nan) + for i, (image, pred) in enumerate(predictions.items()): + for j in range(len(individuals)): + for k in range(len(bodyparts)): + bpt_idx = len(bodyparts) * j + k + ids[i, bpt_idx, 0] = j + + # set keypoints that aren't visible to NaN + gt = ground_truth[image].copy() + gt[gt[..., 2] <= 0, :2] = np.nan + gt = gt[..., :2] + + id_scores = identity_scores[image] + + # reorder to (bodypart, individual, ...) + gt = gt.transpose((1, 0, 2)) + pred = pred.transpose((1, 0, 2))[..., :2] + id_scores = id_scores.transpose((1, 0, 2)) + for bpt, bpt_gt, bpt_pred, bpt_id_scores in zip(bodyparts, gt, pred, id_scores): + # assign ground truth keypoints to the closest prediction, so the ID score + # is the closest possible to the ID score computed with "ground truth" + indices_gt = np.flatnonzero(np.all(~np.isnan(bpt_gt), axis=1)) + + # Remove NaN predictions from the bodypart predictions + indices_pred = np.all(np.isfinite(bpt_pred), axis=1) + bpt_pred = bpt_pred[indices_pred] + bpt_id_scores = bpt_id_scores[indices_pred] + + neighbors = find_closest_neighbors(bpt_gt[indices_gt], bpt_pred, k=3) + found = neighbors != -1 + indices = np.flatnonzero(all_bpts == bpt) + # Get the predicted identity of each bodypart by taking the argmax + ids[i, indices[indices_gt[found]], 1] = np.argmax( + bpt_id_scores[neighbors[found]], axis=1 + ) + + ids = ids.reshape((len(predictions), len(individuals), len(bodyparts), 2)) + results = {} + for i, bpt in enumerate(bodyparts): + temp = ids[:, :, i].reshape((-1, 2)) + valid = np.isfinite(temp).all(axis=1) + y_true, y_pred = temp[valid].T + results[f"{bpt}_accuracy"] = accuracy_score(y_true, y_pred) + + return results diff --git a/deeplabcut/core/metrics/matching.py b/deeplabcut/core/metrics/matching.py new file mode 100644 index 0000000000..95b28ebe5b --- /dev/null +++ b/deeplabcut/core/metrics/matching.py @@ -0,0 +1,169 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Algorithms to match predictions to ground truth labels""" +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np + + +@dataclass +class PotentialMatch: + """A potential match between predicted pose and ground truth pose. + + Args: + pose: An array of shape (num_bodyparts, 3) + score: The score for the prediction. This could be the mean of the confidence + score for each bodypart, or another value representing how confident the + model is that this assembly is correct. + gt: None if no ground truth pose was matched to the prediction. If defined, the + ground truth to which the prediction is matched. It should be of shape + (num_bodyparts, 3), where the 3 values are x, y and visibility. + oks: The OKS score between the pose and the ground truth. + """ + + pose: np.ndarray + score: float + gt: np.ndarray | None = None + oks: float = 0.0 + + def keypoint_scores(self) -> np.ndarray: + """Returns: The confidence score for each bodypart in the predicted pose.""" + return self.pose[:, 2].copy() + + def pixel_errors(self) -> np.ndarray: + """ + Returns: + The distance (in pixels) between each predicted and ground truth bodypart. + If this prediction is unmatched, returns an array of length num_bodyparts + containing all NaNs. + """ + if self.gt is None: + return np.full(len(self.pose), np.nan) + + return np.linalg.norm(self.pose[:, :2] - self.gt[:, :2], axis=1) + + def match(self, gt: np.ndarray, oks: float) -> None: + """Adds a ground truth match to this PotentialMatch + + Args: + gt: The ground truth to which the prediction is matched. The ground truth + pose should be of shape (num_bodyparts, 3), where the 3 values are x, y + and visibility. + oks: The OKS similarity between the ground truth and this. + """ + self.gt = gt + self.oks = oks + + @classmethod + def from_pose(cls, pose: np.ndarray) -> "PotentialMatch": + assert len(pose.shape) == 2 # Must be pose for a single individual + scores = pose[:, 2] + if np.all(np.isnan(scores)): + raise ValueError( + "Cannot create a Match from a pose prediction where all scores are nan " + f"(pose={pose})" + ) + + return PotentialMatch(pose=pose, score=np.nanmean(scores).item()) + + +def match_greedy_oks( + ground_truth: np.ndarray, + predictions: np.ndarray, + oks_matrix: np.ndarray, + oks_threshold: float = 0.0, +) -> list[PotentialMatch]: + """Greedy matching of ground truth individuals to predicted individuals using OKS + + This is done in the same way as done in pycocotools. The predictions must be sorted + by score before being passed to this function. + + Args: + ground_truth: The ground truth labels for an image, of shape (n_idv, n_bpt, 2) + predictions: The predictions for an image, of shape (n_idv, n_bpt, 2) + oks_matrix: A matrix of shape (n_pred, n_kpts) where entry (i, j) is the OKS + between prediction i and ground truth j. + oks_threshold: The min. OKS for a prediction to be matched to a GT pose + + Returns: + A list containing a PotentialMatch for each predicted pose in the given + predictions. + """ + matches = [PotentialMatch.from_pose(pose=pred) for pred in predictions] + matched_gt_indices = set() + for idx, pred in enumerate(predictions): + oks = oks_matrix[idx] + if np.all(np.isnan(oks)): + continue + + ind_best = np.nanargmax(oks) + + # if this gt already matched, continue + if ind_best in matched_gt_indices: + continue + + # Only match the pred to the GT if the OKS value is above a given threshold + if oks[ind_best] < oks_threshold: + continue + + matched_gt_indices.add(ind_best) + matches[idx].match(gt=ground_truth[ind_best], oks=oks[ind_best]) + + return matches + + +def match_greedy_rmse( + ground_truth: np.ndarray, + predictions: np.ndarray, + keep_assemblies: bool = True, +) -> list[PotentialMatch]: + """Greedy matching of ground truth individuals to predicted individuals using RMSE + + The predictions must be sorted by score before being passed to this function. + + Args: + ground_truth: The ground truth labels for an image, of shape (n_idv, n_bpt, 2) + predictions: The predictions for an image, of shape (n_idv, n_bpt, 2) + keep_assemblies: Whether to match predicted keypoints to ground truth keypoints + while enforcing that all bodyparts for a predicted individual are matched + to bodyparts from the same ground truth assembly. When set to False, this + corresponds to detection RMSE score. + + Returns: + A list containing a PotentialMatch for each predicted pose in the given + predictions. + """ + if not keep_assemblies: + raise NotImplementedError() + + matches = [PotentialMatch.from_pose(pose=pred) for pred in predictions] + matched_gt_indices = set() + for idx, pred in enumerate(predictions): + bpt_distances = np.linalg.norm(pred[:, :2] - ground_truth[:, :, :2], axis=-1) + if np.all(np.isnan(bpt_distances)): + continue + + distances = np.nanmean(bpt_distances, axis=-1) + ind_best = np.nanargmin(distances) + + # if this gt already matched, continue + if ind_best in matched_gt_indices: + continue + + matched_gt_indices.add(ind_best) + matches[idx].match( + gt=ground_truth[ind_best], + oks=float("nan"), # don't compute OKS here + ) + + return matches diff --git a/deeplabcut/core/trackingutils.py b/deeplabcut/core/trackingutils.py new file mode 100644 index 0000000000..565582dcd1 --- /dev/null +++ b/deeplabcut/core/trackingutils.py @@ -0,0 +1,840 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + +import abc +import math +import warnings +from collections import defaultdict + +import numpy as np +from filterpy.common import kinematic_kf +from filterpy.kalman import KalmanFilter +from matplotlib import patches +from numba import jit +from numba.core.errors import NumbaPerformanceWarning +from scipy.optimize import linear_sum_assignment +from scipy.stats import mode +from tqdm import tqdm + +warnings.simplefilter("ignore", category=NumbaPerformanceWarning) + +TRACK_METHODS = { + "box": "_bx", + "skeleton": "_sk", + "ellipse": "_el", + "transformer": "_tr", +} + + +def calc_iou(bbox1, bbox2): + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + w = max(0, x2 - x1) + h = max(0, y2 - y1) + wh = w * h + return wh / ( + (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + + (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + - wh + ) + + +class BaseTracker: + """Base class for a constant-velocity Kalman filter-based tracker.""" + + n_trackers = 0 + + def __init__(self, dim, dim_z): + self.kf = kinematic_kf( + dim, + 1, + dim_z=dim_z, + order_by_dim=False, + ) + self.id = self.__class__.n_trackers + self.__class__.n_trackers += 1 + self.time_since_update = 0 + self.age = 0 + self.hits = 0 + self.hit_streak = 0 + + def update(self, z): + self.time_since_update = 0 + self.hits += 1 + self.hit_streak += 1 + self.kf.update(z) + + def predict(self): + self.kf.predict() + self.age += 1 + if self.time_since_update > 0: + self.hit_streak = 0 + self.time_since_update += 1 + return self.state + + @property + def state(self): + return self.kf.x.squeeze()[: self.kf.dim_z] + + @state.setter + def state(self, state): + self.kf.x[: self.kf.dim_z] = state + + +class Ellipse: + def __init__(self, x, y, width, height, theta): + self.x = x + self.y = y + self.width = width + self.height = height + self.theta = theta # in radians + self._geometry = None + + @property + def parameters(self): + return self.x, self.y, self.width, self.height, self.theta + + @property + def aspect_ratio(self): + return max(self.width, self.height) / min(self.width, self.height) + + def calc_similarity_with(self, other_ellipse): + max_dist = max( + self.height, self.width, other_ellipse.height, other_ellipse.width + ) + dist = math.sqrt( + (self.x - other_ellipse.x) ** 2 + (self.y - other_ellipse.y) ** 2 + ) + + if max_dist == 0: + max_dist = 1 + + cost1 = 1 - min(dist / max_dist, 1) + cost2 = abs(math.cos(self.theta - other_ellipse.theta)) + return 0.8 * cost1 + 0.2 * cost2 * cost1 + + def contains_points(self, xy, tol=0.1): + ca = math.cos(self.theta) + sa = math.sin(self.theta) + x_demean = xy[:, 0] - self.x + y_demean = xy[:, 1] - self.y + return ( + ((ca * x_demean + sa * y_demean) ** 2 / (0.5 * self.width) ** 2) + + ((sa * x_demean - ca * y_demean) ** 2 / (0.5 * self.height) ** 2) + ) <= 1 + tol + + def draw(self, show_axes=True, ax=None, **kwargs): + import matplotlib.pyplot as plt + from matplotlib.lines import Line2D + from matplotlib.transforms import Affine2D + + if ax is None: + ax = plt.subplot(111, aspect="equal") + el = patches.Ellipse( + xy=(self.x, self.y), + width=self.width, + height=self.height, + angle=np.rad2deg(self.theta), + **kwargs, + ) + ax.add_patch(el) + if show_axes: + major = Line2D([-self.width / 2, self.width / 2], [0, 0], lw=3, zorder=3) + minor = Line2D([0, 0], [-self.height / 2, self.height / 2], lw=3, zorder=3) + trans = ( + Affine2D().rotate(self.theta).translate(self.x, self.y) + ax.transData + ) + major.set_transform(trans) + minor.set_transform(trans) + ax.add_artist(major) + ax.add_artist(minor) + + +class EllipseFitter: + def __init__(self, sd=2): + self.sd = sd + self.x = None + self.y = None + self.params = None + self._coeffs = None + + def fit(self, xy): + self.x, self.y = xy[np.isfinite(xy).all(axis=1)].T + if len(self.x) < 3: + return None + if self.sd: + self.params = self._fit_error(self.x, self.y, self.sd) + else: + self._coeffs = self._fit(self.x, self.y) + self.params = self.calc_parameters(self._coeffs) + if not np.isnan(self.params).any(): + el = Ellipse(*self.params) + # Regularize by forcing AR <= 5 + # max_ar = 5 + # if el.aspect_ratio >= max_ar: + # if el.height > el.width: + # el.width = el.height / max_ar + # else: + # el.height = el.width / max_ar + # Orient the ellipse such that it encompasses most points + # n_inside = el.contains_points(np.c_[self.x, self.y]).sum() + # el.theta += 0.5 * np.pi + # if el.contains_points(np.c_[self.x, self.y]).sum() < n_inside: + # el.theta -= 0.5 * np.pi + return el + return None + + @staticmethod + @jit(nopython=True) + def _fit(x, y): + """ + Least Squares ellipse fitting algorithm + Fit an ellipse to a set of X- and Y-coordinates. + See Halir and Flusser, 1998 for implementation details + + :param x: ndarray, 1D trajectory + :param y: ndarray, 1D trajectory + :return: 1D ndarray of 6 coefficients of the general quadratic curve: + ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0 + """ + D1 = np.vstack((x * x, x * y, y * y)) + D2 = np.vstack((x, y, np.ones_like(x))) + S1 = D1 @ D1.T + S2 = D1 @ D2.T + S3 = D2 @ D2.T + T = -np.linalg.inv(S3) @ S2.T + temp = S1 + S2 @ T + M = np.zeros_like(temp) + M[0] = temp[2] * 0.5 + M[1] = -temp[1] + M[2] = temp[0] * 0.5 + E, V = np.linalg.eig(M) + cond = 4 * V[0] * V[2] - V[1] ** 2 + a1 = V[:, cond > 0][:, 0] + a2 = T @ a1 + return np.hstack((a1, a2)) + + @staticmethod + @jit(nopython=True) + def _fit_error(x, y, sd): + """ + Fit a sd-sigma covariance error ellipse to the data. + + :param x: ndarray, 1D input of X coordinates + :param y: ndarray, 1D input of Y coordinates + :param sd: int, size of the error ellipse in 'standard deviation' + :return: ellipse center, semi-axes length, angle to the X-axis + """ + cov = np.cov(x, y) + E, V = np.linalg.eigh(cov) # Returns the eigenvalues in ascending order + # r2 = chi2.ppf(2 * norm.cdf(sd) - 1, 2) + # height, width = np.sqrt(E * r2) + height, width = 2 * sd * np.sqrt(E) + a, b = V[:, 1] + rotation = math.atan2(b, a) % np.pi + return [np.mean(x), np.mean(y), width, height, rotation] + + @staticmethod + @jit(nopython=True) + def calc_parameters(coeffs): + """ + Calculate ellipse center coordinates, semi-axes lengths, and + the counterclockwise angle of rotation from the x-axis to the ellipse major axis. + Visit http://mathworld.wolfram.com/Ellipse.html + for how to estimate ellipse parameters. + + :param coeffs: list of fitting coefficients + :return: center: 1D ndarray, semi-axes: 1D ndarray, angle: float + """ + # The general quadratic curve has the form: + # ax^2 + 2bxy + cy^2 + 2dx + 2fy + g = 0 + a, b, c, d, f, g = coeffs + b *= 0.5 + d *= 0.5 + f *= 0.5 + + # Ellipse center coordinates + x0 = (c * d - b * f) / (b * b - a * c) + y0 = (a * f - b * d) / (b * b - a * c) + + # Semi-axes lengths + num = 2 * (a * f * f + c * d * d + g * b * b - 2 * b * d * f - a * c * g) + den1 = (b * b - a * c) * (np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c)) + den2 = (b * b - a * c) * (-np.sqrt((a - c) ** 2 + 4 * b * b) - (a + c)) + major = np.sqrt(num / den1) + minor = np.sqrt(num / den2) + + # Angle to the horizontal + if b == 0: + if a < c: + phi = 0 + else: + phi = np.pi / 2 + else: + if a < c: + phi = np.arctan(2 * b / (a - c)) / 2 + else: + phi = np.pi / 2 + np.arctan(2 * b / (a - c)) / 2 + + return [x0, y0, 2 * major, 2 * minor, phi] + + +class EllipseTracker(BaseTracker): + def __init__(self, params): + super().__init__(dim=5, dim_z=5) + self.kf.R[2:, 2:] *= 10.0 + # High uncertainty to the unobservable initial velocities + self.kf.P[5:, 5:] *= 1000.0 + self.kf.P *= 10.0 + self.kf.Q[5:, 5:] *= 0.01 + self.state = params + + @BaseTracker.state.setter + def state(self, params): + state = np.asarray(params).reshape((-1, 1)) + super(EllipseTracker, type(self)).state.fset(self, state) + + +class SkeletonTracker(BaseTracker): + def __init__(self, n_bodyparts): + super().__init__(dim=n_bodyparts * 2, dim_z=n_bodyparts) + self.kf.Q[self.kf.dim_z :, self.kf.dim_z :] *= 10 + self.kf.R[self.kf.dim_z :, self.kf.dim_z :] *= 0.01 + self.kf.P[self.kf.dim_z :, self.kf.dim_z :] *= 1000 + + def update(self, pose): + flat = pose.reshape((-1, 1)) + empty = np.isnan(flat).squeeze() + if empty.any(): + H = self.kf.H.copy() + H[empty] = 0 + flat[empty] = 0 + self.kf.update(flat, H=H) + else: + super().update(flat) + + @BaseTracker.state.setter + def state(self, pose): + curr_pose = pose.copy() + empty = np.isnan(curr_pose).all(axis=1) + if empty.any(): + fill = np.nanmean(pose, axis=0) + curr_pose[empty] = fill + super(SkeletonTracker, type(self)).state.fset(self, curr_pose.reshape((-1, 1))) + + +class BoxTracker(BaseTracker): + def __init__(self, bbox): + super().__init__(dim=4, dim_z=4) + self.kf = KalmanFilter(dim_x=7, dim_z=4) + self.kf.F = np.array( + [ + [1, 0, 0, 0, 1, 0, 0], + [0, 1, 0, 0, 0, 1, 0], + [0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 1], + ] + ) + self.kf.H = np.array( + [ + [1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + ] + ) + self.kf.R[2:, 2:] *= 10.0 + # Give high uncertainty to the unobservable initial velocities + self.kf.P[4:, 4:] *= 1000.0 + self.kf.P *= 10.0 + self.kf.Q[-1, -1] *= 0.01 + self.kf.Q[4:, 4:] *= 0.01 + self.state = bbox + + def update(self, bbox): + super().update(self.convert_bbox_to_z(bbox)) + + def predict(self): + if (self.kf.x[6] + self.kf.x[2]) <= 0: + self.kf.x[6] *= 0.0 + return super().predict() + + @property + def state(self): + return self.convert_x_to_bbox(self.kf.x)[0] + + @state.setter + def state(self, bbox): + state = self.convert_bbox_to_z(bbox) + super(BoxTracker, type(self)).state.fset(self, state) + + @staticmethod + def convert_x_to_bbox(x, score=None): + """ + Takes a bounding box in the centre form [x,y,s,r] and returns it in the form + [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right + """ + w = np.sqrt(x[2] * x[3]) + h = x[2] / w + if score is None: + return np.array( + [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0] + ).reshape((1, 4)) + else: + return np.array( + [x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score] + ).reshape((1, 5)) + + @staticmethod + def convert_bbox_to_z(bbox): + """ + Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form + [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is + the aspect ratio + """ + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + x = bbox[0] + w / 2.0 + y = bbox[1] + h / 2.0 + s = w * h # scale is just area + r = w / float(h) + return np.array([x, y, s, r]).reshape((4, 1)) + + +class SORTBase(metaclass=abc.ABCMeta): + def __init__(self): + self.n_frames = 0 + self.trackers = [] + + @abc.abstractmethod + def track(self): + pass + + +class SORTEllipse(SORTBase): + def __init__(self, max_age, min_hits, iou_threshold, sd=2): + self.max_age = max_age + self.min_hits = min_hits + self.iou_threshold = iou_threshold + self.fitter = EllipseFitter(sd) + EllipseTracker.n_trackers = 0 + super().__init__() + + def track(self, poses, identities=None): + self.n_frames += 1 + + trackers = np.zeros((len(self.trackers), 6)) + for i in range(len(trackers)): + trackers[i, :5] = self.trackers[i].predict() + empty = np.isnan(trackers).any(axis=1) + trackers = trackers[~empty] + for ind in np.flatnonzero(empty)[::-1]: + self.trackers.pop(ind) + + ellipses = [] + pred_ids = [] + for i, pose in enumerate(poses): + el = self.fitter.fit(pose) + if el is not None: + ellipses.append(el) + if identities is not None: + pred_ids.append(mode(identities[i])[0][0]) + if not len(trackers): + matches = np.empty((0, 2), dtype=int) + unmatched_detections = np.arange(len(ellipses)) + unmatched_trackers = np.empty((0, 6), dtype=int) + else: + ellipses_trackers = [Ellipse(*t[:5]) for t in trackers] + cost_matrix = np.zeros((len(ellipses), len(ellipses_trackers))) + for i, el in enumerate(ellipses): + for j, el_track in enumerate(ellipses_trackers): + cost = el.calc_similarity_with(el_track) + if identities is not None: + match = 2 if pred_ids[i] == self.trackers[j].id_ else 1 + cost *= match + cost_matrix[i, j] = cost + row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=True) + unmatched_detections = [ + i for i, _ in enumerate(ellipses) if i not in row_indices + ] + unmatched_trackers = [ + j for j, _ in enumerate(trackers) if j not in col_indices + ] + matches = [] + for row, col in zip(row_indices, col_indices): + val = cost_matrix[row, col] + # diff = val - cost_matrix + # diff[row, col] += val + # if ( + # val < self.iou_threshold + # or np.any(diff[row] <= 0.2) + # or np.any(diff[:, col] <= 0.2) + # ): + if val < self.iou_threshold: + unmatched_detections.append(row) + unmatched_trackers.append(col) + else: + matches.append([row, col]) + if not len(matches): + matches = np.empty((0, 2), dtype=int) + else: + matches = np.stack(matches) + unmatched_trackers = np.asarray(unmatched_trackers) + unmatched_detections = np.asarray(unmatched_detections) + + animalindex = [] + for t, tracker in enumerate(self.trackers): + if t not in unmatched_trackers: + ind = matches[matches[:, 1] == t, 0][0] + animalindex.append(ind) + tracker.update(ellipses[ind].parameters) + else: + animalindex.append(-1) + + for i in unmatched_detections: + trk = EllipseTracker(ellipses[i].parameters) + if identities is not None: + trk.id_ = mode(identities[i])[0][0] + self.trackers.append(trk) + animalindex.append(i) + + i = len(self.trackers) + ret = [] + for trk in reversed(self.trackers): + d = trk.state + if (trk.time_since_update < 1) and ( + trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits + ): + ret.append( + np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape( + 1, -1 + ) + ) # for DLC we also return the original animalid + # +1 as MOT benchmark requires positive >> this is removed for DLC! + i -= 1 + # remove dead tracklet + if trk.time_since_update > self.max_age: + self.trackers.pop(i) + + if len(ret) > 0: + return np.concatenate(ret) + return np.empty((0, 7)) + + +class SORTSkeleton(SORTBase): + def __init__(self, n_bodyparts, max_age=20, min_hits=3, oks_threshold=0.5): + self.n_bodyparts = n_bodyparts + self.max_age = max_age + self.min_hits = min_hits + self.oks_threshold = oks_threshold + SkeletonTracker.n_trackers = 0 + super().__init__() + + @staticmethod + def weighted_hausdorff(x, y): + # Modified from scipy source code: + # - to restrict its use to 2D + # - to get rid of shuffling (since arrays are only (nbodyparts * 3) element long) + # TODO - factor in keypoint confidence (and weight by # of observations??) + cmax = 0 + for i in range(x.shape[0]): + no_break_occurred = True + cmin = np.inf + for j in range(y.shape[0]): + d = (x[i, 0] - y[j, 0]) ** 2 + (x[i, 1] - y[j, 1]) ** 2 + if d < cmax: + no_break_occurred = False + break + if d < cmin: + cmin = d + if cmin != np.inf and cmin > cmax and no_break_occurred: + cmax = cmin + return np.sqrt(cmax) + + @staticmethod + def object_keypoint_similarity(x, y): + mask = ~np.isnan(x * y).all(axis=1) # Intersection visible keypoints + xx = x[mask] + yy = y[mask] + dist = np.linalg.norm(xx - yy, axis=1) + scale = np.sqrt( + np.product(np.ptp(yy, axis=0)) + ) # square root of bounding box area + oks = np.exp(-0.5 * (dist / (0.05 * scale)) ** 2) + return np.mean(oks) + + def calc_pairwise_hausdorff_dist(self, poses, poses_ref): + mat = np.zeros((len(poses), len(poses_ref))) + for i, pose in enumerate(poses): + for j, pose_ref in enumerate(poses_ref): + mat[i, j] = self.weighted_hausdorff(pose, pose_ref) + return mat + + def calc_pairwise_oks(self, poses, poses_ref): + mat = np.zeros((len(poses), len(poses_ref))) + for i, pose in enumerate(poses): + for j, pose_ref in enumerate(poses_ref): + mat[i, j] = self.object_keypoint_similarity(pose, pose_ref) + return mat + + def track(self, poses): + self.n_frames += 1 + + if not len(self.trackers): + for pose in poses: + tracker = SkeletonTracker(self.n_bodyparts) + tracker.state = pose + self.trackers.append(tracker) + + poses_ref = [] + for i, tracker in enumerate(self.trackers): + pose_ref = tracker.predict() + poses_ref.append(pose_ref.reshape((-1, 2))) + + # mat = self.calc_pairwise_oks(poses, poses_ref) + mat = self.calc_pairwise_hausdorff_dist(poses, poses_ref) + row_indices, col_indices = linear_sum_assignment(mat, maximize=False) + + unmatched_poses = [p for p, _ in enumerate(poses) if p not in row_indices] + unmatched_trackers = [ + t for t, _ in enumerate(poses_ref) if t not in col_indices + ] + # Remove matched detections with low OKS + # matches = [] + # for row, col in zip(row_indices, col_indices): + # if mat[row, col] < self.oks_threshold: + # unmatched_poses.append(row) + # unmatched_trackers.append(col) + # else: + # matches.append([row, col]) + # if not len(matches): + # matches = np.empty((0, 2), dtype=int) + # else: + # matches = np.stack(matches) + matches = np.c_[row_indices, col_indices] + + animalindex = [] + for t, tracker in enumerate(self.trackers): + if t not in unmatched_trackers: + ind = matches[matches[:, 1] == t, 0][0] + animalindex.append(ind) + tracker.update(poses[ind]) + else: + animalindex.append(-1) + + for i in unmatched_poses: + tracker = SkeletonTracker(self.n_bodyparts) + tracker.state = poses[i] + self.trackers.append(tracker) + animalindex.append(i) + + states = [] + i = len(self.trackers) + for tracker in reversed(self.trackers): + i -= 1 + if tracker.time_since_update > self.max_age: + self.trackers.pop() + continue + state = tracker.predict() + states.append(np.r_[state, [tracker.id, int(animalindex[i])]]) + if len(states) > 0: + return np.stack(states) + return np.empty((0, self.n_bodyparts * 2 + 2)) + + +class SORTBox(SORTBase): + def __init__(self, max_age, min_hits, iou_threshold): + self.max_age = max_age + self.min_hits = min_hits + self.iou_threshold = iou_threshold + BoxTracker.n_trackers = 0 + super().__init__() + + def track(self, dets): + self.n_frames += 1 + + trackers = np.zeros((len(self.trackers), 5)) + for i in range(len(trackers)): + trackers[i, :4] = self.trackers[i].predict() + empty = np.isnan(trackers).any(axis=1) + trackers = trackers[~empty] + for ind in np.flatnonzero(empty)[::-1]: + self.trackers.pop(ind) + + matched, unmatched_dets, unmatched_trks = self.match_detections_to_trackers( + dets, trackers, self.iou_threshold + ) + + # update matched trackers with assigned detections + animalindex = [] + for t, trk in enumerate(self.trackers): + if t not in unmatched_trks: + d = matched[np.where(matched[:, 1] == t)[0], 0] + animalindex.append(d[0]) + trk.update(dets[d, :][0]) # update coordinates + else: + animalindex.append("nix") # lost trk! + + # create and initialise new trackers for unmatched detections + for i in unmatched_dets: + trk = BoxTracker(dets[i, :]) + self.trackers.append(trk) + animalindex.append(i) + + i = len(self.trackers) + ret = [] + for trk in reversed(self.trackers): + d = trk.state + if (trk.time_since_update < 1) and ( + trk.hit_streak >= self.min_hits or self.n_frames <= self.min_hits + ): + ret.append( + np.concatenate((d, [trk.id, int(animalindex[i - 1])])).reshape( + 1, -1 + ) + ) # for DLC we also return the original animalid + # +1 as MOT benchmark requires positive >> this is removed for DLC! + i -= 1 + # remove dead tracklet + if trk.time_since_update > self.max_age: + self.trackers.pop(i) + + if len(ret) > 0: + return np.concatenate(ret) + return np.empty((0, 5)) + + @staticmethod + def match_detections_to_trackers(detections, trackers, iou_threshold): + """ + Assigns detections to tracked object (both represented as bounding boxes) + + Returns 3 lists of matches, unmatched_detections and unmatched_trackers + """ + if not len(trackers): + return ( + np.empty((0, 2), dtype=int), + np.arange(len(detections)), + np.empty((0, 5), dtype=int), + ) + iou_matrix = np.zeros((len(detections), len(trackers)), dtype=np.float32) + + for d, det in enumerate(detections): + for t, trk in enumerate(trackers): + iou_matrix[d, t] = calc_iou(det, trk) + row_indices, col_indices = linear_sum_assignment(-iou_matrix) + + unmatched_detections = [] + for d, det in enumerate(detections): + if d not in row_indices: + unmatched_detections.append(d) + unmatched_trackers = [] + for t, trk in enumerate(trackers): + if t not in col_indices: + unmatched_trackers.append(t) + + # filter out matched with low IOU + matches = [] + for row, col in zip(row_indices, col_indices): + if iou_matrix[row, col] < iou_threshold: + unmatched_detections.append(row) + unmatched_trackers.append(col) + else: + matches.append([row, col]) + if not len(matches): + matches = np.empty((0, 2), dtype=int) + else: + matches = np.stack(matches) + return matches, np.array(unmatched_detections), np.array(unmatched_trackers) + + +def fill_tracklets(tracklets, trackers, animals, imname): + for content in trackers: + tracklet_id, pred_id = content[-2:].astype(int) + if tracklet_id not in tracklets: + tracklets[tracklet_id] = {} + if pred_id != -1: + tracklets[tracklet_id][imname] = np.asarray(animals[pred_id]) + else: # Resort to the tracker prediction + xy = np.asarray(content[:-2]) + pred = np.insert(xy, range(2, len(xy) + 1, 2), 1) + tracklets[tracklet_id][imname] = np.asarray(pred) + + +def calc_bboxes_from_keypoints(data, slack=0, offset=0): + data = np.asarray(data) + if data.shape[-1] < 3: + raise ValueError("Data should be of shape (n_animals, n_bodyparts, 3)") + + if data.ndim != 3: + data = np.expand_dims(data, axis=0) + bboxes = np.full((data.shape[0], 5), np.nan) + bboxes[:, :2] = np.nanmin(data[..., :2], axis=1) - slack # X1, Y1 + bboxes[:, 2:4] = np.nanmax(data[..., :2], axis=1) + slack # X2, Y2 + bboxes[:, -1] = np.nanmean(data[..., 2]) # Average confidence + bboxes[:, [0, 2]] += offset + return bboxes + + +def reconstruct_all_ellipses(data, sd): + xy = data.droplevel("scorer", axis=1).drop("likelihood", axis=1, level=-1) + if "single" in xy: + xy.drop("single", axis=1, level="individuals", inplace=True) + animals = xy.columns.get_level_values("individuals").unique() + nrows = xy.shape[0] + ellipses = np.full((len(animals), nrows, 5), np.nan) + fitter = EllipseFitter(sd) + for n, animal in enumerate(animals): + data = xy.xs(animal, axis=1, level="individuals").values.reshape((nrows, -1, 2)) + for i, coords in enumerate(tqdm(data)): + el = fitter.fit(coords.astype(np.float64)) + if el is not None: + ellipses[n, i] = el.parameters + return ellipses + + +def _track_individuals( + individuals, min_hits=1, max_age=5, similarity_threshold=0.6, track_method="ellipse" +): + if track_method not in TRACK_METHODS: + raise ValueError(f"Unknown {track_method} tracker.") + + if track_method == "ellipse": + tracker = SORTEllipse(max_age, min_hits, similarity_threshold) + elif track_method == "box": + tracker = SORTBox(max_age, min_hits, similarity_threshold) + else: + n_bodyparts = individuals[0][0].shape[0] + tracker = SORTSkeleton(n_bodyparts, max_age, min_hits, similarity_threshold) + + tracklets = defaultdict(dict) + all_hyps = dict() + for i, (multi, single) in enumerate(tqdm(individuals)): + if single is not None: + tracklets["single"][i] = single + if multi is None: + continue + if track_method == "box": + # TODO: get cropping parameters and utilize! + xy = calc_bboxes_from_keypoints(multi) + else: + xy = multi[..., :2] + hyps = tracker.track(xy) + all_hyps[i] = hyps + for hyp in hyps: + tracklet_id, pred_id = hyp[-2:].astype(int) + if pred_id != -1: + tracklets[tracklet_id][i] = multi[pred_id] + return tracklets, all_hyps diff --git a/deeplabcut/core/visualization.py b/deeplabcut/core/visualization.py new file mode 100644 index 0000000000..1911a4f769 --- /dev/null +++ b/deeplabcut/core/visualization.py @@ -0,0 +1,236 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Visualization methods for """ +from __future__ import annotations + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +def form_figure(nx, ny) -> tuple[plt.Figure, plt.Axes]: + """Forms a figure on which to plot images""" + fig, ax = plt.subplots(frameon=False) + ax.set_xlim(0, nx) + ax.set_ylim(0, ny) + ax.axis("off") + ax.invert_yaxis() + fig.tight_layout() + return fig, ax + + +def visualize_scoremaps( + image: np.ndarray, + scmap: np.ndarray, +) -> tuple[plt.Figure, plt.Axes]: + """Plots scoremaps as an image overlay. + + Args: + image: An image as a numpy array of shape (h, w, channels) + scmap: A scoremap of shape (h, w) + + Returns: + The figure and axis on which the image scoremap was plot. + """ + ny, nx = np.shape(image)[:2] + fig, ax = form_figure(nx, ny) + ax.imshow(image) + ax.imshow(scmap, alpha=0.5) + return fig, ax + + +def visualize_locrefs( + image: np.ndarray, + scmap: np.ndarray, + locref_x: np.ndarray, + locref_y: np.ndarray, + step: int = 5, + zoom_width: int = 0, +) -> tuple[plt.Figure, plt.Axes]: + """Plots a scoremap and the corresponding location refinement field on an image. + + Args: + image: An image as a numpy array of shape (h, w, channels) + scmap: A scoremap of shape (h, w) + locref_x: The x-coordinate of the location refinement field, of shape (h, w) + locref_y: The y-coordinate of the location refinement field, of shape (h, w) + step: The step with which to plot the location refinement field. + zoom_width: The zoom width with which to plot the scoremaps. + + Returns: + The figure and axis on which the image scoremap and locref field were plot. + """ + fig, ax = visualize_scoremaps(image, scmap) + X, Y = np.meshgrid(np.arange(locref_x.shape[1]), np.arange(locref_x.shape[0])) + M = np.zeros(locref_x.shape, dtype=bool) + M[scmap < 0.5] = True + U = np.ma.masked_array(locref_x, mask=M) + V = np.ma.masked_array(locref_y, mask=M) + ax.quiver( + X[::step, ::step], + Y[::step, ::step], + U[::step, ::step], + V[::step, ::step], + color="r", + units="x", + scale_units="xy", + scale=1, + angles="xy", + ) + if zoom_width > 0: + maxloc = np.unravel_index(np.argmax(scmap), scmap.shape) + ax.set_xlim(maxloc[1] - zoom_width, maxloc[1] + zoom_width) + ax.set_ylim(maxloc[0] + zoom_width, maxloc[0] - zoom_width) + return fig, ax + + +def visualize_paf( + image: np.ndarray, + paf: np.ndarray, + step: int = 5, + colors: list | None = None, +) -> tuple[plt.Figure, plt.Axes]: + """Plots the PAF on top of the image. + + Args: + image: Shape (height, width, channels). The image on which the model was run. + paf: Shape (height, width, 2 * len(paf_graph)). The PAF output by the model. + step: The step with which to plot the scoremaps. + colors: The colormap to use. + + Returns: + The figure and axis on which the image PAF was plot. + """ + ny, nx = np.shape(image)[:2] + fig, ax = form_figure(nx, ny) + ax.imshow(image) + n_fields = paf.shape[2] + if colors is None: + colors = ["r"] * n_fields + for n in range(n_fields): + U = paf[:, :, n, 0] + V = paf[:, :, n, 1] + X, Y = np.meshgrid(np.arange(U.shape[1]), np.arange(U.shape[0])) + M = np.zeros(U.shape, dtype=bool) + M[U**2 + V**2 < 0.5 * 0.5**2] = True + U = np.ma.masked_array(U, mask=M) + V = np.ma.masked_array(V, mask=M) + ax.quiver( + X[::step, ::step], + Y[::step, ::step], + U[::step, ::step], + V[::step, ::step], + scale=50, + headaxislength=4, + alpha=1, + width=0.002, + color=colors[n], + angles="xy", + ) + return fig, ax + + +def generate_model_output_plots( + output_folder: Path, + image_name: str, + bodypart_names: list[str], + bodyparts_to_plot: list[str], + image: np.ndarray, + scmap: np.ndarray, + locref: np.ndarray | None = None, + paf: np.ndarray | None = None, + paf_graph: list[tuple[int, int]] | None = None, + paf_all_in_one: bool = True, + paf_colormap: str = "rainbow", + output_suffix: str = "", +) -> None: + """Generates model output plots (maps) for an image and saves them to disk. + + Args: + output_folder: The folder in which the plots should be saved. + image_name: The name of the image for which the plots were generated. + bodypart_names: The names of bodyparts the model outputs. + bodyparts_to_plot: The names of bodyparts that should be plot. + image: Shape (height, width, channels). The image on which the model was run. + scmap: Shape (height, width, num_bodyparts). The scoremaps output by the model. + locref: Shape (height, width, num_bodyparts, 2). Optionally, the location + refinement fields output by the model. + paf: Shape (height, width, 2 * len(paf_graph)). Optionally, the part-affinity + fields output by the model. + paf_graph: Must be set if paf is not None. The PAF graph used to assemble. + paf_all_in_one: Whether to plot all PAFs in a single image. + paf_colormap: The colormap to use for the PAF maps. + output_suffix: The filename suffix for the maps to output. + """ + def _filename(map_name) -> str: + return f"{image_name}_{map_name}_{output_suffix}.png" + + to_plot = [ + i for i, bpt in enumerate(bodypart_names) if bpt in bodyparts_to_plot + ] + if len(to_plot) > 1: + map_ = scmap[:, :, to_plot].sum(axis=2) + elif len(to_plot) == 1 and len(bodypart_names) > 1: + map_ = scmap[:, :, to_plot[0]] + else: + map_ = scmap[..., 0] + + fig1, _ = visualize_scoremaps(image, map_) + fig1.savefig(output_folder / _filename("scmap")) + + if locref is not None: + if len(to_plot) > 1: + map_ = scmap[:, :, to_plot] + locref_x_ = locref[:, :, to_plot, 0] + locref_y_ = locref[:, :, to_plot, 1] + # only get the locref fields around their respective detections + locref_x_[map_ < 0.5] = 0 + locref_y_[map_ < 0.5] = 0 + # combine locrefs + map_ = map_.sum(axis=2) + locref_x_ = locref_x_.sum(axis=2) + locref_y_ = locref_y_.sum(axis=2) + elif len(to_plot) == 1 and len(bodypart_names) > 1: + locref_x_ = locref[:, :, to_plot[0], 0] + locref_y_ = locref[:, :, to_plot[0], 1] + else: + locref_x_ = locref[..., 0] + locref_y_ = locref[..., 1] + + fig2, _ = visualize_locrefs(image, map_, locref_x_, locref_y_) + fig2.savefig(output_folder / _filename("locref")) + + if paf is not None: + if paf_graph is None: + raise ValueError(f"When plotting the PAF, you must pass the ``paf_graph``") + + edge_list = [] + for n, edge in enumerate(paf_graph): + if any(ind in to_plot for ind in edge): + e0, e1 = edge + edge_list.append( + [(2 * n, 2 * n + 1), (bodypart_names[e0], bodypart_names[e1])] + ) + + if paf_all_in_one: + inds = [elem[0] for elem in edge_list] + n_inds = len(inds) + cmap = plt.cm.get_cmap(paf_colormap, n_inds) + colors = cmap(range(n_inds)) + fig3, _ = visualize_paf(image, paf[:, :, inds], colors=colors) + fig3.savefig(output_folder / _filename("paf")) + else: + for inds, names in edge_list: + fig3, _ = visualize_paf(image, paf[:, :, [inds]]) + fig3.savefig(output_folder / _filename(f"paf_{'_'.join(names)}")) + + plt.close("all") diff --git a/deeplabcut/core/weight_init.py b/deeplabcut/core/weight_init.py new file mode 100644 index 0000000000..8a6d374e8a --- /dev/null +++ b/deeplabcut/core/weight_init.py @@ -0,0 +1,206 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Classes to configure how to initialize model weights""" +from __future__ import annotations + +import warnings +from dataclasses import dataclass +from pathlib import Path + +import numpy as np + + +@dataclass +class WeightInitialization: + """Configures weights initialization when transfer learning or fine-tuning models + + Args: + snapshot_path: The path to the snapshot used to initialize pose model weights + when training a model. + detector_snapshot_path: The path to the snapshot used to initialize detector + weights when training a model. + dataset: Optionally, the dataset on which the snapshots were trained. Required + when fine-tuning SuperAnimal models. + with_decoder: Whether to load the decoder weights as well. + memory_replay: Only when ``with_decoder=True``. Whether to train the model with + memory replay, so that it predicts all SuperAnimal (or previous project) + bodyparts. + conversion_array: The mapping from SuperAnimal (or other project, on which the + weights were trained) to project bodyparts. Required when + `with_decoder=True`. + An array [7, 0, 1] means the project has 3 bodyparts, where the 1st bodypart + corresponds to the 8th bodypart in the pretrained model, the 2nd to the 1st + and the 3rd to the 2nd (as arrays are 0-indexed). + bodyparts: Optionally, the name of each bodypart entry in the conversion array. + """ + + snapshot_path: Path + detector_snapshot_path: Path | None = None + dataset: str | None = None + with_decoder: bool = False + memory_replay: bool = False + conversion_array: np.ndarray | None = None + bodyparts: list[str] | None = None + + def __post_init__(self): + if self.memory_replay and not self.with_decoder: + raise ValueError( + "You cannot train a model with memory replay if you do not keep the " + "decoder layers (``with_decoder=True``), but you passed " + "`memory_replay=True` and `with_decoder=False`. Please change your " + "WeightInitialization parameters." + ) + + if self.with_decoder and self.conversion_array is None: + raise ValueError( + f"You must specify a conversion_array to initialize decoder weights " + f"(``with_decoder=True``)." + ) + + if self.bodyparts is not None and self.conversion_array is None: + raise ValueError( + f"Specifying bodyparts should only be done when `with_decoder=True` and" + f" the conversion array is specified." + ) + + if self.conversion_array is not None and self.bodyparts is not None: + if not len(self.conversion_array) == len(self.bodyparts): + raise ValueError( + f"There must be the same number of elements in the bodyparts list " + "and conv. array; found {self.bodyparts}, {self.conversion_array}" + ) + + def to_dict(self) -> dict: + """Returns: the weight initialization as a dict""" + data = dict() + if self.dataset is not None: + data["dataset"] = self.dataset + + data["snapshot_path"] = str(self.snapshot_path) + if self.detector_snapshot_path is not None: + data["detector_snapshot_path"] = str(self.detector_snapshot_path) + + data["with_decoder"] = self.with_decoder + data["memory_replay"] = self.memory_replay + + if self.conversion_array is not None: + data["conversion_array"] = self.conversion_array.tolist() + + if self.bodyparts is not None: + data["bodyparts"] = self.bodyparts + + return data + + @staticmethod + def from_dict(data: dict) -> "WeightInitialization": + if "snapshot_path" not in data: + return WeightInitialization.from_dict_legacy(data) + + detector_snapshot_path = data.get("detector_snapshot_path") + if detector_snapshot_path is not None: + detector_snapshot_path = Path(detector_snapshot_path) + + conversion_array = data.get("conversion_array") + if conversion_array is not None: + conversion_array = np.array(conversion_array, dtype=int) + + return WeightInitialization( + snapshot_path=Path(data["snapshot_path"]), + detector_snapshot_path=detector_snapshot_path, + dataset=data.get("dataset"), + with_decoder=data["with_decoder"], + memory_replay=data["memory_replay"], + conversion_array=conversion_array, + bodyparts=data.get("bodyparts"), + ) + + @staticmethod + def from_dict_legacy(data: dict) -> "WeightInitialization": + """Deals with weight initialization that were created before 3.0.0rc5""" + import deeplabcut.pose_estimation_pytorch.modelzoo.utils as utils + + conversion_array = data.get("conversion_array") + if conversion_array is not None: + conversion_array = np.array(conversion_array, dtype=int) + + return WeightInitialization( + snapshot_path=utils.get_super_animal_snapshot_path( + dataset=data["dataset"], + model_name="hrnet_w32", + ), + detector_snapshot_path=utils.get_super_animal_snapshot_path( + dataset=data["dataset"], + model_name="fasterrcnn_resnet50_fpn_v2", + ), + with_decoder=data["with_decoder"], + memory_replay=data["memory_replay"], + conversion_array=conversion_array, + bodyparts=data.get("bodyparts"), + ) + + @staticmethod + def build( + cfg: dict, + super_animal: str, + model_name: str = "hrnet_w32", + detector_name: str = "fasterrcnn_resnet50_fpn_v2", + with_decoder: bool = False, + memory_replay: bool = False, + customized_pose_checkpoint: str | None = None, + customized_detector_checkpoint: str | None = None, + ) -> "WeightInitialization": + """Builds a WeightInitialization for a project + + `WeightInitialization.build` is deprecated and will be removed in a future + version of DeepLabCut. Please use `build_weight_init` from `deeplabcut.modelzoo` + instead. + + Args: + cfg: The project's configuration. + super_animal: The SuperAnimal model with which to initialize weights. + model_name: The name of the model architecture for which to load the weights + (defaults to "hrnet_w32" for backwards compatibility). + detector_name: The name of the detector architecture for which to load the + weights (defaults to "fasterrcnn_resnet50_fpn_v2" for backwards + compatibility). + with_decoder: Whether to load the decoder weights as well. If this is true, + a conversion table must be specified for the given SuperAnimal in the + project configuration file. See + ``deeplabcut.modelzoo.utils.create_conversion_table`` to create a + conversion table. + memory_replay: Only when ``with_decoder=True``. Whether to train the model + with memory replay, so that it predicts all SuperAnimal bodyparts. + customized_pose_checkpoint: A customized SuperAnimal pose checkpoint, as an + alternative to the Hugging Face one + customized_detector_checkpoint: A customized SuperAnimal detector + checkpoint, as an alternative to the Hugging Face one + + Returns: + The built WeightInitialization. + """ + from deeplabcut.modelzoo import build_weight_init + deprecation_warning = ( + "The `WeightInitialization.build` is deprecated and will be removed in a " + "future version of DeepLabCut. Please use `build_weight_init` from " + "`deeplabcut.modelzoo` instead." + ) + warnings.warn(deprecation_warning, DeprecationWarning) + + return build_weight_init( + cfg, + super_animal, + model_name, + detector_name, + with_decoder, + memory_replay, + customized_pose_checkpoint, + customized_detector_checkpoint, + ) diff --git a/deeplabcut/create_project/demo_data.py b/deeplabcut/create_project/demo_data.py index c495c89ba6..9e53eb7770 100644 --- a/deeplabcut/create_project/demo_data.py +++ b/deeplabcut/create_project/demo_data.py @@ -13,10 +13,15 @@ from pathlib import Path import deeplabcut +from deeplabcut.core.engine import Engine from deeplabcut.utils import auxiliaryfunctions -def load_demo_data(config, createtrainingset=True): +def load_demo_data( + config: str, + createtrainingset: bool = True, + engine: Engine = Engine.PYTORCH, +): """ Loads the demo data -- subset from trail-tracking data in Mathis et al. 2018. When loading, it sets paths correctly to run this project on your system @@ -29,6 +34,9 @@ def load_demo_data(config, createtrainingset=True): createtrainingset : bool Boolean variable indicating if a training set shall be created. + engine: Engine + The Engine to create the training set for if a training set shall be created. + Example -------- >>> deeplabcut.load_demo_data('config.yaml') @@ -40,7 +48,7 @@ def load_demo_data(config, createtrainingset=True): transform_data(config) if createtrainingset: print("Loaded, now creating training data...") - deeplabcut.create_training_dataset(config, num_shuffles=1) + deeplabcut.create_training_dataset(config, num_shuffles=1, engine=engine) def transform_data(config): diff --git a/deeplabcut/create_project/modelzoo.py b/deeplabcut/create_project/modelzoo.py index 76679f1867..7b853f9542 100644 --- a/deeplabcut/create_project/modelzoo.py +++ b/deeplabcut/create_project/modelzoo.py @@ -13,13 +13,32 @@ from pathlib import Path import yaml - -import deeplabcut -from deeplabcut.utils import auxiliaryfunctions +from dlclibrary import get_available_detectors from dlclibrary.dlcmodelzoo.modelzoo_download import ( download_huggingface_model, MODELOPTIONS, + get_available_datasets, + get_available_models, +) + +import deeplabcut +from deeplabcut import Engine +from deeplabcut.core.config import read_config_as_dict, write_config +from deeplabcut.generate_training_dataset.metadata import ( + TrainingDatasetMetadata, + ShuffleMetadata, + DataSplit, +) +from deeplabcut.generate_training_dataset.trainingsetmanipulation import ( + MakeInference_yaml, +) +from deeplabcut.modelzoo.utils import get_super_animal_project_cfg +from deeplabcut.pose_estimation_pytorch.config.make_pose_config import ( + add_metadata, + make_pytorch_test_config, ) +from deeplabcut.pose_estimation_pytorch.modelzoo.utils import load_super_animal_config +from deeplabcut.utils import auxiliaryfunctions Modeloptions = MODELOPTIONS # backwards compatibility for COLAB NOTEBOOK @@ -96,17 +115,22 @@ def create_pretrained_human_project( def create_pretrained_project( - project, - experimenter, - videos, - model="full_human", - working_directory=None, - copy_videos=False, - videotype="", - analyzevideo=True, - filtered=True, - createlabeledvideo=True, - trainFraction=None, + project: str, + experimenter: str, + videos: list[str], + model: str | None = None, + working_directory: str | None = None, + copy_videos: bool = False, + videotype: str = "", + analyzevideo: bool = True, + filtered: bool = True, + createlabeledvideo: bool = True, + trainFraction: float | None = None, + engine: Engine = Engine.PYTORCH, + multi_animal: bool = False, + individuals: list[str] | None = None, + net_name: str | None = None, + detector_name: str | None = None, ): """ Creates a new project directory, sub-directories and a basic configuration file. @@ -124,43 +148,408 @@ def create_pretrained_project( experimenter : string String containing the name of the experimenter. - model: string, options see http://www.mousemotorlab.org/dlc-modelzoo - Current option and default: 'full_human' Creates a demo human project and analyzes a video with ResNet 101 weights pretrained on MPII Human Pose. This is from the DeeperCut paper - by Insafutdinov et al. https://arxiv.org/abs/1605.03170 Please make sure to cite it too if you use this code! + model: string | None, default = None, + The model / dataset to use as basis for the project. + If None, the default model / dataset for the selected engine will be used. - videos : list + videos : list[string] A list of string containing the full paths of the videos to include in the project. - working_directory : string, optional - The directory where the project will be created. The default is the ``current working directory``; if provided, it must be a string. + working_directory : string, optional, default = None + The directory where the project will be created. If None - the current working directory will be used. - copy_videos : bool, optional ON WINDOWS: TRUE is often necessary! - If this is set to True, the videos are copied to the ``videos`` directory. If it is False,symlink of the videos are copied to the project/videos directory. The default is ``False``; if provided it must be either - ``True`` or ``False``. + copy_videos : bool, optional, default = False, + If this is set to True, the videos are copied to the ``videos`` directory. + If it is False, symlink of the videos are copied to the project/videos directory. + Note: on Windows: True is often necessary! - analyzevideo " bool, optional - If true, then the video is analyzed and a labeled video is created. If false, then only the project will be created and the weights downloaded. You can then access them + analyzevideo: bool, optional + If true, then the video is analyzed and a labeled video is created. + If false, then only the project will be created and the weights downloaded. - filtered: bool, default false - Boolean variable indicating if filtered pose data output should be plotted rather than frame-by-frame predictions. - Filtered version can be calculated with deeplabcut.filterpredictions + filtered: bool, default True + Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions. + Filtered version can be calculated with deeplabcut.filterpredictions() - trainFraction: By default value from *new* projects. (0.95) + createlabeledvideo: bool, default True, + Specifies if a labeled video needs to be created. + + trainFraction: float|None, default = None. Fraction that will be used in dlc-model/trainingset folder name. + If None - default value (0.95) from new projects will be used. + + engine: Engine, default Engine.PYTORCH, + engine on which the pretrained weights are based + + multi_animal: bool = False, + Specifies if the project is single or multi-animal. + Implemented only for Pytorch-based models. + + individuals: list[str] | None = None, + Only if multianimal is True. + Defines the names of the individuals. + + net_name: str | None, default = None, + Valid only if using Pytorch engine. + Name of the pose model on which the superanimal dataset has been trained on. + If None - "hrnet_w32" will be used as default. + + detector_name: str | None, default = None, + Valid only if using Pytorch engine. + Name of the detector model on which the superanimal dataset has been trained on. + If None - "fasterrcnn_resnet50_fpn_v2" will be used as default. Example -------- Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi - >>> deeplabcut.create_pretrained_project('humanstrokestudy','Linus',['/data/videos/homosapiens1.avi'], copy_videos=False) + >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False) Loading full_cat model and analyzing video "felixfeliscatus3.avi" - >>> deeplabcut.create_pretrained_project('humanstrokestudy','Linus',['/data/videos/felixfeliscatus3.avi'], model='full_cat') + >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF) Windows: - >>> deeplabcut.create_pretrained_project('humanstrokestudy','Bill',[r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'],r'C:\yourusername\analysis\project' copy_videos=True) + >>> deeplabcut.create_pretrained_project("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True) Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ ) + """ + if engine == Engine.TF: + return create_pretrained_project_tensorflow( + project=project, + experimenter=experimenter, + videos=videos, + model=model, + working_directory=working_directory, + copy_videos=copy_videos, + videotype=videotype, + analyzevideo=analyzevideo, + filtered=filtered, + createlabeledvideo=createlabeledvideo, + trainFraction=trainFraction, + ) + elif engine == Engine.PYTORCH: + return create_pretrained_project_pytorch( + project=project, + experimenter=experimenter, + videos=videos, + dataset=model, + working_directory=working_directory, + copy_videos=copy_videos, + video_type=videotype, + analyze_video=analyzevideo, + filtered=filtered, + create_labeled_video=createlabeledvideo, + train_fraction=trainFraction, + multi_animal=multi_animal, + individuals=individuals, + net_name=net_name, + detector_name=detector_name, + ) + + raise NotImplementedError(f"This function is not implemented for {engine}") + + +def create_pretrained_project_pytorch( + project: str, + experimenter: str, + videos: list[str], + dataset: str | None = None, + working_directory: str | None = None, + copy_videos: bool = False, + video_type: str | None = None, + analyze_video: bool = True, + filtered: bool = True, + create_labeled_video: bool = True, + train_fraction: float | None = None, + multi_animal: bool = False, + individuals: list[str] | None = None, + net_name: str | None = None, + detector_name: str | None = None, +): + """ + Method used specifically for Pytorch-based ModelZoo models. + + Creates a new project directory, sub-directories and a basic configuration file. + Change its parameters to your projects need. + + The project will also be initialized with a pre-trained model from the DeepLabCut model zoo! + + http://modelzoo.deeplabcut.org + + Parameters + ---------- + project : string + String containing the name of the project. + + experimenter : string + String containing the name of the experimenter. + + dataset: string|None, default = None, + The superanimal dataset to use as basis for the project. + If not specified - superanimal_quadruped will be used by default. + + videos : list[string] + A list of string containing the full paths of the videos to include in the project. + + working_directory : string, optional, default = None + The directory where the project will be created. If None - the current working directory will be used. + + copy_videos : bool, optional, default = False, + If this is set to True, the videos are copied to the ``videos`` directory. + If it is False, symlink of the videos are copied to the project/videos directory. + Note: on Windows: True is often necessary! + + analyze_video: bool, optional + If true, then the video is analyzed and a labeled video is created. + If false, then only the project will be created and the weights downloaded. + + filtered: bool, default True + Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions. + Filtered version can be calculated with deeplabcut.filterpredictions() + + create_labeled_video: bool, default True + Specifies if a labeled video needs to be created. + + train_fraction: float|None, default = None. + Fraction that will be used in dlc-model/trainingset folder name. + If None - default value (0.95) from new projects will be used. + + multi_animal: bool = False, + Specifies if the project is single or multi-animal + + individuals: list[str]|None = None, + Only if multianimal is True. + Defines the names of the individuals. + + net_name: str | None, default = None, + Valid only if using Pytorch engine. + Name of the pose model on which the superanimal dataset has been trained on. + If None - "hrnet_w32" will be used as default. + detector_name: str | None, default = None, + Valid only if using Pytorch engine. + Name of the detector model on which the superanimal dataset has been trained on. + If None - "fasterrcnn_resnet50_fpn_v2" will be used as default. + + Example + -------- + Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi + >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False) + + Loading full_cat model and analyzing video "felixfeliscatus3.avi" + >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF) + + Windows: + >>> deeplabcut.create_pretrained_project_pytorch("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True) + Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ ) """ + # Check arguments + if not dataset: + dataset = "superanimal_quadruped" + + if not net_name: + net_name = "hrnet_w32" + + # Currently, all Pytorch Superanimal models are Top-Down. + if not detector_name: + detector_name = "fasterrcnn_resnet50_fpn_v2" + + if dataset not in get_available_datasets(): + raise ValueError( + f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}" + ) + + if net_name not in get_available_models(dataset): + raise ValueError( + f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}" + ) + + if detector_name not in get_available_detectors(dataset): + raise ValueError( + f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}" + ) + + # Create project + cfg_path = deeplabcut.create_new_project( + project=project, + experimenter=experimenter, + videos=videos, + working_directory=working_directory, + copy_videos=copy_videos, + videotype=video_type, + multianimal=multi_animal, + individuals=individuals, + ) + + # Edits to do to the project config + cfg_edits = {} + if train_fraction is not None: + cfg_edits["TrainingFraction"] = [train_fraction] + super_animal_project_cfg = get_super_animal_project_cfg(dataset) + super_animal_bodyparts = super_animal_project_cfg.get("bodyparts") + super_animal_skeleton = super_animal_project_cfg.get("skeleton") + cfg_edits["skeleton"] = super_animal_skeleton + if multi_animal: + cfg_edits["multianimalbodyparts"] = super_animal_bodyparts + else: + cfg_edits["bodyparts"] = super_animal_bodyparts + auxiliaryfunctions.edit_config(cfg_path, edits=cfg_edits) + + # Create the shuffle train and test directories + config = read_config_as_dict(cfg_path) + shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder( + trainFraction=config["TrainingFraction"][0], + shuffle=1, + cfg=config, + engine=Engine.PYTORCH, + ) + train_dir = shuffle_dir / "train" + test_dir = shuffle_dir / "test" + train_dir.mkdir(parents=True, exist_ok=True) + test_dir.mkdir(parents=True, exist_ok=True) + + # Download the weights and put them into appropriate directory + print("Downloading weights...") + super_animal_detector_name = f"{dataset}_{detector_name}" + new_detector_name = "snapshot-detector-000.pt" + download_huggingface_model( + model_name=super_animal_detector_name, + target_dir=str(train_dir), + rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name}, + ) + super_animal_model_name = f"{dataset}_{net_name}" + new_snapshot_name = "snapshot-000.pt" + download_huggingface_model( + model_name=super_animal_model_name, + target_dir=str(train_dir), + rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name}, + ) + + # Create pytorch_config.yaml + train_cfg_path = train_dir / "pytorch_config.yaml" + pytorch_config = load_super_animal_config( + super_animal=dataset, + model_name=net_name, + detector_name=detector_name, + ) + pytorch_config = add_metadata(config, pytorch_config, train_cfg_path) + pytorch_config["resume_training_from"] = str(train_dir / new_snapshot_name) + pytorch_config["detector"]["resume_training_from"] = str( + train_dir / new_detector_name + ) + write_config(train_cfg_path, pytorch_config) + + # Create test pose_cfg.yaml + test_cfg_path = test_dir / "pose_cfg.yaml" + make_pytorch_test_config( + model_config=pytorch_config, test_config_path=test_cfg_path, save=True + ) + + # Create inference_cfg.yaml if needed + if multi_animal: + inference_cfg_path = test_dir / "inference_cfg.yaml" + _create_inference_config(inference_cfg_path, config) + + # Create metadata.yaml with shuffle info in training-data directory + _create_training_datasets_metadata(config, shuffle_dir.name, Engine.PYTORCH) + + # Process the videos + _process_videos( + cfg_path=cfg_path, + video_type=video_type, + analyze_video=analyze_video, + filtered=filtered, + create_labeled_video=create_labeled_video, + ) + return cfg_path, str(train_cfg_path) + + +def _create_inference_config(inference_cfg_path: str | Path, project_cfg: dict): + inf_updates = dict( + minimalnumberofconnections=int(len(project_cfg["multianimalbodyparts"]) / 2), + topktoretain=len(project_cfg["individuals"]), + withid=project_cfg.get("identity", False), + ) + default_inf_path = ( + Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml" + ) + MakeInference_yaml(inf_updates, inference_cfg_path, default_inf_path) + + +def create_pretrained_project_tensorflow( + project: str, + experimenter: str, + videos: list[str], + model: str | None = None, + working_directory: str | None = None, + copy_videos: bool = False, + videotype: str = "", + analyzevideo: bool = True, + filtered: bool = True, + createlabeledvideo: bool = True, + trainFraction: float | None = None, +): + """ + Method used specifically for Tensorflow-based ModelZoo models. + + Creates a new project directory, sub-directories and a basic configuration file. + Change its parameters to your projects need. + + The project will also be initialized with a pre-trained model from the DeepLabCut model zoo! + + http://modelzoo.deeplabcut.org + + Parameters + ---------- + project : string + String containing the name of the project. + + experimenter : string + String containing the name of the experimenter. + + model: string|None, default = None, + The model / dataset to use as basis for the project. + If not specified - full_human will be used by default. + + videos : list[string] + A list of string containing the full paths of the videos to include in the project. + + working_directory : string, optional, default = None + The directory where the project will be created. If None - the current working directory will be used. + + copy_videos : bool, optional, default = False, + If this is set to True, the videos are copied to the ``videos`` directory. + If it is False, symlink of the videos are copied to the project/videos directory. + Note: on Windows: True is often necessary! + + analyzevideo: bool, optional + If true, then the video is analyzed and a labeled video is created. + If false, then only the project will be created and the weights downloaded. + + filtered: bool, default True + Indicates if filtered pose data output should be plotted rather than frame-by-frame predictions. + Filtered version can be calculated with deeplabcut.filterpredictions() + + createlabeledvideo: bool, default True + Specifies if a labeled video needs to be created. + + trainFraction: float|None, default = None. + Fraction that will be used in dlc-model/trainingset folder name. + If None - default value (0.95) from new projects will be used. + + Example + -------- + Linux/MacOs loading full_human model and analyzing video /homosapiens1.avi + >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Linus", ["/data/videos/homosapiens1.avi"], copy_videos=False) + + Loading full_cat model and analyzing video "felixfeliscatus3.avi" + >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Linus", ["/data/videos/felixfeliscatus3.avi"], model="full_cat", engine=Engine.TF) + + Windows: + >>> deeplabcut.create_pretrained_project_tensorflow("humanstrokestudy", "Bill", [r'C:\yourusername\rig-95\Videos\reachingvideo1.avi'], r'C:\yourusername\analysis\project', copy_videos=True) + Users must format paths with either: r'C:\ OR 'C:\\ <- i.e. a double backslash \ \ ) + """ + if not model: + model = "full_human" + if model in MODELOPTIONS: cwd = os.getcwd() @@ -300,23 +689,71 @@ def create_pretrained_project( MakeTest_pose_yaml(pose_cfg, keys2save, path_test_config) - video_dir = os.path.join(config["project_path"], "videos") - if analyzevideo == True: - print("Analyzing video...") - deeplabcut.analyze_videos(cfg, [video_dir], videotype, save_as_csv=True) - - if createlabeledvideo == True: - if filtered: - deeplabcut.filterpredictions(cfg, [video_dir], videotype) + _create_training_datasets_metadata(config, modelfoldername.name, Engine.TF) - print("Plotting results...") - deeplabcut.create_labeled_video( - cfg, [video_dir], videotype, draw_skeleton=True, filtered=filtered - ) - deeplabcut.plot_trajectories(cfg, [video_dir], videotype, filtered=filtered) + _process_videos( + cfg_path=cfg, + video_type=videotype, + analyze_video=analyzevideo, + filtered=filtered, + create_labeled_video=createlabeledvideo, + ) os.chdir(cwd) return cfg, path_train_config else: return "N/A", "N/A" + + +def _create_training_datasets_metadata( + config: dict, shuffle_dir_name: str, engine: Engine +): + # First create the metadata object + metadata = TrainingDatasetMetadata.create(config) + + # Create a new shuffle with TensorFlow engine + new_shuffle = ShuffleMetadata( + name=shuffle_dir_name, + train_fraction=config["TrainingFraction"][0], + index=1, + engine=engine, + split=DataSplit(train_indices=(), test_indices=()), + ) + + # Add the shuffle to metadata + metadata = metadata.add(new_shuffle) + + # Save the metadata + metadata.save() + + return metadata + + +def _process_videos( + cfg_path: str | Path, + video_type: str = "", + analyze_video: bool = True, + filtered: bool = True, + create_labeled_video: bool = True, +): + cfg_path = str(cfg_path) + video_dir = Path(cfg_path).parent / "videos" + + if analyze_video: + print("Analyzing video...") + deeplabcut.analyze_videos( + cfg_path, [video_dir], videotype=video_type, save_as_csv=True + ) + + if create_labeled_video: + if filtered: + deeplabcut.filterpredictions(cfg_path, [video_dir], video_type) + + print("Plotting results...") + deeplabcut.create_labeled_video( + cfg_path, [video_dir], video_type, draw_skeleton=True, filtered=filtered + ) + deeplabcut.plot_trajectories( + cfg_path, [video_dir], video_type, filtered=filtered + ) diff --git a/deeplabcut/create_project/new.py b/deeplabcut/create_project/new.py index 0fc7a06139..f812d89bb9 100644 --- a/deeplabcut/create_project/new.py +++ b/deeplabcut/create_project/new.py @@ -19,13 +19,14 @@ def create_new_project( - project, - experimenter, - videos, - working_directory=None, - copy_videos=False, - videotype="", - multianimal=False, + project: str, + experimenter: str, + videos: list[str], + working_directory: str | None = None, + copy_videos: bool = False, + videotype: str = "", + multianimal: bool = False, + individuals: list[str] | None = None, ): r"""Create the necessary folders and files for a new project. @@ -58,6 +59,11 @@ def create_new_project( multianimal: bool, optional. Default: False. For creating a multi-animal project (introduced in DLC 2.2) + individuals: list[str]|None = None, + Relevant only if multianimal is True. + list of individuals to be used in the project configuration. + If None - defaults to ['individual1', 'individual2', 'individual3'] + Returns ------- str @@ -143,7 +149,9 @@ def create_new_project( # Check if it is a folder if os.path.isdir(i): vids_in_dir = [ - os.path.join(i, vp) for vp in os.listdir(i) if vp.endswith(videotype) + os.path.join(i, vp) + for vp in os.listdir(i) + if vp.lower().endswith(videotype) ] vids = vids + vids_in_dir if len(vids_in_dir) == 0: @@ -239,7 +247,11 @@ def create_new_project( cfg_file, ruamelFile = auxiliaryfunctions.create_config_template(multianimal) cfg_file["multianimalproject"] = multianimal cfg_file["identity"] = False - cfg_file["individuals"] = ["individual1", "individual2", "individual3"] + cfg_file["individuals"] = ( + individuals + if individuals + else ["individual1", "individual2", "individual3"] + ) cfg_file["multianimalbodyparts"] = ["bodypart1", "bodypart2", "bodypart3"] cfg_file["uniquebodyparts"] = [] cfg_file["bodyparts"] = "MULTI!" @@ -272,6 +284,7 @@ def create_new_project( cfg_file["TrainingFraction"] = [0.95] cfg_file["iteration"] = 0 cfg_file["snapshotindex"] = -1 + cfg_file["detector_snapshotindex"] = -1 cfg_file["x1"] = 0 cfg_file["x2"] = 640 cfg_file["y1"] = 277 @@ -279,6 +292,7 @@ def create_new_project( cfg_file["batch_size"] = ( 8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242 ) + cfg_file["detector_batch_size"] = 1 cfg_file["corner2move2"] = (50, 50) cfg_file["move2corner"] = True cfg_file["skeleton_color"] = "black" diff --git a/deeplabcut/generate_training_dataset/__init__.py b/deeplabcut/generate_training_dataset/__init__.py index 05b0092d49..60eac17c1d 100644 --- a/deeplabcut/generate_training_dataset/__init__.py +++ b/deeplabcut/generate_training_dataset/__init__.py @@ -13,3 +13,8 @@ from deeplabcut.generate_training_dataset.frame_extraction import * from deeplabcut.generate_training_dataset.trainingsetmanipulation import * from deeplabcut.generate_training_dataset.multiple_individuals_trainingsetmanipulation import * +from deeplabcut.generate_training_dataset.metadata import ( + DataSplit, + ShuffleMetadata, + TrainingDatasetMetadata, +) diff --git a/deeplabcut/generate_training_dataset/metadata.py b/deeplabcut/generate_training_dataset/metadata.py new file mode 100644 index 0000000000..a26a5eadda --- /dev/null +++ b/deeplabcut/generate_training_dataset/metadata.py @@ -0,0 +1,481 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""File containing methods to load and parse shuffle metadata""" +from __future__ import annotations + +import logging +import pickle +import re +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +from ruamel.yaml import YAML + +from deeplabcut.core.engine import Engine +from deeplabcut.utils import auxiliaryfunctions + + +@dataclass(frozen=True) +class DataSplit: + """Class representing the metadata for a shuffle""" + train_indices: tuple[int, ...] + test_indices: tuple[int, ...] + + def __post_init__(self) -> None: + """ + Raises: + ValueError if the indices are not sorted in increasing + """ + for indices in [self.train_indices, self.test_indices]: + idx = np.array(indices) + if not np.all(idx[:-1] < idx[1:]): + raise RuntimeError( + f"The training and test indices in a data split must be sorted in " + f"strictly ascending order." + ) + + +@dataclass(frozen=True) +class ShuffleMetadata: + """Class representing the metadata for a shuffle""" + name: str + train_fraction: float + index: int + engine: Engine + split: DataSplit | None + + def load_split(self, cfg: dict, trainset_path: Path) -> "ShuffleMetadata": + """Loads the data split for this shuffle + + Args: + cfg: the config for the DeepLabCut project + trainset_path: the path to the training dataset folder + + Returns: + a new instance with the data split defined + """ + _, doc_path = auxiliaryfunctions.get_data_and_metadata_filenames( + trainset_path, self.train_fraction, self.index, cfg + ) + if not Path(doc_path).exists(): + raise ValueError( + f"Could not load the metadata file for {self} as {doc_path} does not " + f"exist. If you deleted the shuffle, you also need to delete the " + f"shuffle from metadata.yaml or recreate the metadata.yaml file." + ) + + with open(doc_path, "rb") as f: + _, train_idx, test_idx, _ = pickle.load(f) + return ShuffleMetadata( + name=self.name, + train_fraction=self.train_fraction, + index=self.index, + engine=self.engine, + split=DataSplit( + train_indices=tuple(sorted([int(idx) for idx in train_idx])), + test_indices=tuple(sorted([int(idx) for idx in test_idx])), + ) + ) + + +@dataclass(frozen=True) +class TrainingDatasetMetadata: + """An immutable class containing the metadata for a dataset + + When creating a new "training-datasets" folder (e.g., when creating the first + training set for a project, or when creating the first training for a given + iteration of a project), TrainingDatasetMetadata.create(cfg) should be called when + the "training-datasets" folder is still empty. + + For existing projects (created with DeepLabCut < 3.0), calling + TrainingDatasetMetadata.create(cfg) will go over documentation data for all existing + shuffles in the training-datasets folder and add them to a new metadata instance. + All shuffles will be given Engine.TF as an engine. + + Examples: + # Creating the metadata file for an existing project + config = "/data/my-dlc-project/config.yaml" + trainset_metadata = TrainingDatasetMetadata.create(config) + trainset_metadata.save() + + # Adding a new shuffle to the metadata file + config = "/data/my-dlc-project-2008-06-17/config.yaml" + trainset_metadata = TrainingDatasetMetadata.load(config) + new_shuffle = ShuffleMetadata( + name="my-dlc-projectJun17-trainset60shuffle5", + train_fraction=0.6, + index=5, + engine=compat.Engine.PYTORCH, + split=DataSplit(train_indices=(1, 3, 4), test_indices=(0, 2)), + ) + trainset_metadata = trainset_metadata.add(new_shuffle) + trainset_metadata.save() # saves to disk + """ + project_config: dict + shuffles: tuple[ShuffleMetadata, ...] + file_header: tuple[str] = ( + "# This file is automatically generated - DO NOT EDIT", + "# It contains the information about the shuffles created for the dataset", + "---", + ) + + def __post_init__(self) -> None: + """ + Raises: + ValueError if the indices are not sorted in increasing order + """ + indices = [[s.train_fraction, s.index] for s in self.shuffles] + for (frac1, idx1), (frac2, idx2) in zip(indices[:-1], indices[1:]): + if not (frac1 < frac2 or (frac1 == frac2 and idx1 < idx2)): + raise RuntimeError( + "The shuffles given must be sorted in order of ascending training " + f"fraction and index. Found {self.shuffles}" + ) + + def add( + self, + shuffle: ShuffleMetadata, + overwrite: bool = False, + ) -> TrainingDatasetMetadata: + """ + Adds a new shuffle to the metadata file + + Args: + shuffle: the shuffle to add + overwrite: if a shuffle with the same index is already stored in the + metadata file, whether to overwrite it + + Returns: + A new instance of TrainingDatasetMetadata with updated shuffles + + Raises: + ValueError: if overwrite=False and there is already a shuffle with the given + index in the metadata file. + """ + existing_indices = [ + s.index for s in self.shuffles if s.train_fraction == shuffle.train_fraction + ] + if shuffle.index in existing_indices: + if not overwrite: + raise RuntimeError( + f"Cannot add {shuffle} to the meta: a shuffle with index " + f"{shuffle.index} and train_fraction {shuffle.train_fraction} " + f"already exists: {self.shuffles}." + ) + + existing_shuffles = [ + s + for s in self.shuffles + if (s.index != shuffle.index or s.train_fraction != shuffle.train_fraction) + ] + shuffles = existing_shuffles + [shuffle] + return TrainingDatasetMetadata( + project_config=self.project_config, + shuffles=tuple(sorted(shuffles, key=lambda s: (s.train_fraction, s.index))), + ) + + def get(self, trainset_index: int = 0, index: int = 0) -> ShuffleMetadata: + """ + Args: + trainset_index: the index of the trainset fraction as defined in config.yaml + index: the index of the shuffle + + Returns: + the shuffle with the given trainset index and shuffle index + + Raises: + ValueError if the shuffle is not present in the metadata + """ + train_fraction = self.project_config["TrainingFraction"][trainset_index] + for shuffle in self.shuffles: + if ( + shuffle.train_fraction == train_fraction + and shuffle.index == index + ): + return shuffle + + raise ValueError( + f"Could not find a shuffle with trainingset fraction {train_fraction} and " + f"index {index}" + ) + + def save(self) -> None: + """Saves the training dataset metadata to disk""" + metadata = {"shuffles": {}} + data_splits: dict[DataSplit, int] = {} + trainset_path = self.path(self.project_config).parent + for s in self.shuffles: + if s.split is None: + s = s.load_split(cfg=self.project_config, trainset_path=trainset_path) + + split_index = data_splits.get(s.split) + if split_index is None: + split_index = len(data_splits) + 1 + data_splits[s.split] = split_index + + metadata["shuffles"][s.name] = { + "train_fraction": s.train_fraction, + "index": s.index, + "split": split_index, + "engine": s.engine.aliases[0], + } + + with open(self.path(self.project_config), "w") as file: + file.write("\n".join(self.file_header) + "\n") + YAML().dump(metadata, file) + + @staticmethod + def load( + config: str | Path | dict, + load_splits: bool = False, + ) -> TrainingDatasetMetadata: + """Loads the metadata from disk + + Args: + config: the config for the DeepLabCut project (or its path) + load_splits: whether to load the data split for each shuffle + """ + if isinstance(config, (str, Path)): + cfg = auxiliaryfunctions.read_config(config) + else: + cfg = config + + metadata_path = TrainingDatasetMetadata.path(cfg) + with open(metadata_path, "r") as file: + metadata = YAML(typ="safe", pure=True).load(file) + + shuffles = [] + for shuffle_name, shuffle_metadata in metadata["shuffles"].items(): + shuffle = ShuffleMetadata( + name=shuffle_name, + train_fraction=shuffle_metadata["train_fraction"], + index=shuffle_metadata["index"], + engine=Engine(shuffle_metadata["engine"]), + split=None, + ) + if load_splits: + shuffle = shuffle.load_split(cfg, metadata_path.parent) + + shuffles.append(shuffle) + + shuffles.sort(key=lambda s: (s.train_fraction, s.index)) + return TrainingDatasetMetadata(project_config=cfg, shuffles=tuple(shuffles)) + + @staticmethod + def create(config: str | Path | dict) -> TrainingDatasetMetadata: + """Function to create the metadata file + + Assumes that all existing shuffles use the TensorFlow engine, as this file + should have already been created for PyTorch shuffles. + + Args; + config: the config for the DeepLabCut project (or its path) + default_engine: the default engine to set for shuffles in the project + + Returns: + the metadata for the existing shuffles in the project + """ + if isinstance(config, (str, Path)): + cfg = auxiliaryfunctions.read_config(config) + else: + cfg = config + + trainset_path = TrainingDatasetMetadata.path(cfg).parent + if trainset_path.exists(): + shuffle_docs = [ + f + for f in trainset_path.iterdir() + if re.match(r"Documentation_data-.+shuffle[0-9]+\.pickle", f.name) + ] + else: + trainset_path.mkdir(parents=True) + shuffle_docs = [] + + prefix = cfg["Task"] + cfg["date"] + shuffles = [] + existing_splits: dict[tuple[tuple[int, ...], tuple[int, ...]], int] = {} + for doc_path in shuffle_docs: + index = int(doc_path.stem.split("shuffle")[-1]) + with open(doc_path, "rb") as f: + _, train_idx, test_idx, train_frac = pickle.load(f) + + engine = Engine.TF + train_idx = tuple(sorted([int(idx) for idx in train_idx])) + test_idx = tuple(sorted([int(idx) for idx in test_idx])) + split_idx = existing_splits.get((train_idx, test_idx)) + if split_idx is None: + split_idx = len(existing_splits) + 1 + existing_splits[(train_idx, test_idx)] = split_idx + + shuffles.append( + ShuffleMetadata( + name=f"{prefix}-trainset{int(100 * train_frac)}shuffle{index}", + train_fraction=train_frac, + index=index, + engine=engine, + split=DataSplit(train_indices=train_idx, test_indices=test_idx), + ) + ) + + shuffles = tuple(sorted(shuffles, key=lambda s: (s.train_fraction, s.index))) + return TrainingDatasetMetadata( + project_config=cfg, + shuffles=shuffles, + ) + + @staticmethod + def path(cfg: dict) -> Path: + """ + Args: + cfg: the config for the DeepLabCut project + + Returns: + the path to the training dataset metadata file + """ + meta_path = auxiliaryfunctions.get_training_set_folder(cfg) / "metadata.yaml" + return Path(cfg["project_path"]) / meta_path + + +def update_metadata( + cfg: dict, + train_fraction: float, + shuffle: int, + engine: Engine, + train_indices: list[int], + test_indices: list[int], + overwrite: bool = False, +) -> None: + """Updates the metadata for a training-dataset + + Args: + cfg: the config for the DeepLabCut project + train_fraction: the train_fraction of the new shuffle + shuffle: the index of the shuffle to add + engine: the engine for the shuffle + train_indices: the indices of images in the training set + test_indices: the indices of images in the test set + overwrite: whether to overwrite a shuffle with the same index and train fraction + if one exists + + Raises: + ValueError: if overwrite=False and there is already a shuffle with the given + index in the metadata file. + """ + prefix = cfg["Task"] + cfg["date"] + metadata = TrainingDatasetMetadata.load(cfg, load_splits=True) + new_shuffle = ShuffleMetadata( + name=f"{prefix}-trainset{int(100 * train_fraction)}shuffle{shuffle}", + train_fraction=train_fraction, + index=shuffle, + engine=engine, + split=DataSplit( + train_indices=tuple(sorted([int(i) for i in train_indices])), + test_indices=tuple(sorted([int(i) for i in test_indices])), + ) + ) + metadata = metadata.add(shuffle=new_shuffle, overwrite=overwrite) + metadata.save() + + +def get_shuffle_engine( + cfg: dict, + trainingsetindex: int, + shuffle: int, + modelprefix: str = "", +) -> Engine: + """ + Args: + cfg: the config for the DeepLabCut project + trainingsetindex: the training set index used + shuffle: the shuffle for which to get the engine + modelprefix: the model prefix, if there is one + + Returns: + the engine that the shuffle was created with + + Raises: + ValueError if the engine for the shuffle cannot be determined or the shuffle + doesn't exist + """ + if not TrainingDatasetMetadata.path(cfg).exists(): + metadata = TrainingDatasetMetadata.create(cfg) + metadata.save() + + metadata = TrainingDatasetMetadata.load(cfg) + shuffle_metadata = metadata.get(trainingsetindex, shuffle) + if modelprefix: + # try to get the engine by checking which models folder exists + engines = find_engines_from_model_folders( + cfg, trainingsetindex, shuffle, modelprefix + ) + if len(engines) == 0: + raise ValueError( + f"Couldn't find any shuffles with trainingsetindex={trainingsetindex}, " + f"shuffle={shuffle} and modelprefix={modelprefix}. Please check that " + f"such a shuffle is defined." + ) + + if len(engines) == 1: + return engines.pop() + + if shuffle_metadata.engine in engines: + engine = shuffle_metadata.engine + else: + engine = engines.pop() # take a random engine + + logging.warning( + f"Found multiple engines for trainingsetindex={trainingsetindex}, " + f"shuffle={shuffle} and modelprefix={modelprefix}. Using engine={engine}. " + f"To select another engine, please specify it in your API call." + ) + return engine + + return shuffle_metadata.engine + + +def find_engines_from_model_folders( + cfg: dict, + trainingsetindex: int, + shuffle: int, + modelprefix: str = "", +) -> set[Engine]: + """Determines which engines are used with a given shuffle. + + This method can be useful when using modelprefix, as the engine for a shuffle stored + under a "modelprefix" might not be the same as the base shuffle (for which the + engine is stored in the training-datasets folder). + + Args: + cfg: the config for the DeepLabCut project + trainingsetindex: the training set index used + shuffle: the shuffle for which to get the engine + modelprefix: the model prefix, if there is one + + Returns: + the engines for which a model folder exists for the given shuffle + """ + project_path = Path(cfg["project_path"]) + train_fraction = cfg["TrainingFraction"][trainingsetindex] + + existing_engines = set() + for engine in Engine: + expected_model_folder = project_path / auxiliaryfunctions.get_model_folder( + trainFraction=train_fraction, + shuffle=shuffle, + cfg=cfg, + engine=engine, + modelprefix=modelprefix, + ) + if expected_model_folder.exists(): + existing_engines.add(engine) + + return existing_engines diff --git a/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py b/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py index 10a41c4d6f..1095dfef6f 100755 --- a/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py +++ b/deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py @@ -8,6 +8,7 @@ # # Licensed under GNU Lesser General Public License v3.0 # +from __future__ import annotations import os import os.path @@ -19,6 +20,10 @@ import numpy as np from tqdm import tqdm +import deeplabcut.compat as compat +import deeplabcut.generate_training_dataset.metadata as metadata +from deeplabcut.core.engine import Engine +from deeplabcut.core.weight_init import WeightInitialization from deeplabcut.generate_training_dataset import ( merge_annotateddatasets, read_image_shape_fast, @@ -27,6 +32,7 @@ MakeTest_pose_yaml, MakeInference_yaml, pad_train_test_indices, + validate_shuffles, ) from deeplabcut.utils import ( auxiliaryfunctions, @@ -101,6 +107,7 @@ def create_multianimaltraining_dataset( Shuffles=None, windows2linux=False, net_type=None, + detector_type=None, numdigits=2, crop_size=(400, 400), crop_sampling="hybrid", @@ -109,15 +116,20 @@ def create_multianimaltraining_dataset( testIndices=None, n_edges_threshold=105, paf_graph_degree=6, + userfeedback: bool = True, + weight_init: WeightInitialization | None = None, + engine: Engine | None = None, ): """ - Creates a training dataset for multi-animal datasets. Labels from all the extracted frames are merged into a single .h5 file.\n + Creates a training dataset for multi-animal datasets. Labels from all the extracted + frames are merged into a single .h5 file.\n Only the videos included in the config file are used to create this dataset.\n - [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more videos to the project. + [OPTIONAL] Use the function 'add_new_videos' at any stage of the project to add more + videos to the project. Important differences to standard: - stores coordinates with numdigits as many digits - - creates + Parameter ---------- config : string @@ -130,17 +142,53 @@ def create_multianimaltraining_dataset( Alternatively the user can also give a list of shuffles (integers!). net_type: string - Type of networks. Currently resnet_50, resnet_101, and resnet_152, efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, - efficientnet-b4, efficientnet-b5, and efficientnet-b6 as well as dlcrnet_ms5 are supported (not the MobileNets!). - See Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1 + Type of networks. The options available depend on which engine is used. See + Lauer et al. 2021 https://www.biorxiv.org/content/10.1101/2021.04.30.442096v1 + Currently supported options are: + TensorFlow + * ``resnet_50`` + * ``resnet_101`` + * ``resnet_152`` + * ``efficientnet-b0`` + * ``efficientnet-b1`` + * ``efficientnet-b2`` + * ``efficientnet-b3`` + * ``efficientnet-b4`` + * ``efficientnet-b5`` + * ``efficientnet-b6`` + PyTorch (call ``deeplabcut.pose_estimation_pytorch.available_models()`` for + a complete list) + * ``resnet_50`` + * ``resnet_101`` + * ``dekr_w18`` + * ``dekr_w32`` + * ``dekr_w48`` + * ``top_down_resnet_50`` + * ``top_down_resnet_101`` + * ``top_down_hrnet_w18`` + * ``top_down_hrnet_w32`` + * ``top_down_hrnet_w48`` + * ``animaltokenpose_base`` + + detector_type: string, optional, default=None + Only for the PyTorch engine. + When passing creating shuffles for top-down models, you can specify which + detector you want. If the detector_type is None, the ```ssdlite``` will be used. + The list of all available detectors can be obtained by calling + ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported options: + * ``ssdlite`` + * ``fasterrcnn_mobilenet_v3_large_fpn`` + * ``fasterrcnn_resnet50_fpn_v2`` numdigits: int, optional crop_size: tuple of int, optional + Only for the TensorFlow engine. Dimensions (width, height) of the crops for data augmentation. Default is 400x400. crop_sampling: str, optional + Only for the TensorFlow engine. Crop centers sampling method. Must be either: "uniform" (randomly over the image), "keypoints" (randomly over the annotated keypoints), @@ -149,6 +197,7 @@ def create_multianimaltraining_dataset( Default is "hybrid". paf_graph: list of lists, or "config" optional (default=None) + Only for the TensorFlow engine. If not None, overwrite the default complete graph. This is useful for advanced users who already know a good graph, or simply want to use a specific one. Note that, in that case, the data-driven selection procedure upon model evaluation will be skipped. @@ -163,11 +212,27 @@ def create_multianimaltraining_dataset( List of one or multiple lists containing test indexes. n_edges_threshold: int, optional (default=105) + Only for the TensorFlow engine. Number of edges above which the graph is automatically pruned. paf_graph_degree: int, optional (default=6) + Only for the TensorFlow engine. Degree of paf_graph when automatically pruning it (before training). + userfeedback: bool, optional, default=True + If ``False``, all requested train/test splits are created (no matter if they + already exist). If you want to assure that previous splits etc. are not + overwritten, set this to ``True`` and you will be asked for each split. + + weight_init: WeightInitialisation, optional, default=None + PyTorch engine only. Specify how model weights should be initialized. The + default mode uses transfer learning from ImageNet weights. + + engine: Engine, optional + Whether to create a pose config for a Tensorflow or PyTorch model. Defaults to + the value specified in the project configuration file. If no engine is specified + for the project, defaults to ``deeplabcut.compat.DEFAULT_ENGINE``. + Example -------- >>> deeplabcut.create_multianimaltraining_dataset('/analysis/project/reaching-task/config.yaml',num_shuffles=1) @@ -202,6 +267,11 @@ def create_multianimaltraining_dataset( full_training_path = Path(project_path, trainingsetfolder) auxiliaryfunctions.attempt_to_make_folder(full_training_path, recursive=True) + # Create the trainset metadata file, if it doesn't yet exist + if not metadata.TrainingDatasetMetadata.path(cfg).exists(): + trainset_metadata = metadata.TrainingDatasetMetadata.create(cfg) + trainset_metadata.save() + Data = merge_annotateddatasets(cfg, full_training_path) if Data is None: return @@ -209,13 +279,21 @@ def create_multianimaltraining_dataset( if net_type is None: # loading & linking pretrained models net_type = cfg.get("default_net_type", "dlcrnet_ms5") - elif not any(net in net_type for net in ("resnet", "eff", "dlc", "mob")): - raise ValueError(f"Unsupported network {net_type}.") + + # load the engine to use to create the shuffle + if engine is None: + engine = compat.get_project_engine(cfg) + + if not ( + any(net in net_type for net in ("resnet", "eff", "dlc", "mob")) + or engine == Engine.PYTORCH + ): + raise ValueError(f"Unsupported network {net_type} for engine {engine}.") multi_stage = False ### dlcnet_ms5: backbone resnet50 + multi-fusion & multi-stage module ### dlcr101_ms5/dlcr152_ms5: backbone resnet101/152 + multi-fusion & multi-stage module - if all(net in net_type for net in ("dlcr", "_ms5")): + if all(net in net_type for net in ("dlcr", "_ms5")) and engine != Engine.PYTORCH: num_layers = re.findall("dlcr([0-9]*)", net_type)[0] if num_layers == "": num_layers = 50 @@ -272,12 +350,13 @@ def create_multianimaltraining_dataset( # Loading the encoder (if necessary downloading from TF) dlcparent_path = auxiliaryfunctions.get_deeplabcut_path() defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml") - model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path)) - if Shuffles is None: - Shuffles = range(1, num_shuffles + 1, 1) + if engine == Engine.PYTORCH: + model_path = dlcparent_path else: - Shuffles = [i for i in Shuffles if isinstance(i, int)] + model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path)) + + Shuffles = validate_shuffles(cfg, Shuffles, num_shuffles, userfeedback) # print(trainIndices,testIndices, Shuffles, augmenter_type,net_type) if trainIndices is None and testIndices is None: @@ -309,6 +388,11 @@ def create_multianimaltraining_dataset( test_inds = test_inds[test_inds != -1] splits.append((trainFraction, Shuffles[shuffle], (train_inds, test_inds))) + top_down = False + if engine == Engine.PYTORCH and net_type.startswith("top_down_"): + top_down = True + net_type = net_type[len("top_down_") :] + for trainFraction, shuffle, (trainIndices, testIndices) in splits: #################################################### # Generating data structure with labeled information & frame metadata (for deep cut) @@ -345,6 +429,15 @@ def create_multianimaltraining_dataset( testIndices, trainFraction, ) + metadata.update_metadata( + cfg=cfg, + train_fraction=trainFraction, + shuffle=shuffle, + engine=engine, + train_indices=trainIndices, + test_indices=testIndices, + overwrite=not userfeedback, + ) datafilename = datafilename.split(".mat")[0] + ".pickle" import pickle @@ -359,7 +452,10 @@ def create_multianimaltraining_dataset( ################################################################################# modelfoldername = auxiliaryfunctions.get_model_folder( - trainFraction, shuffle, cfg + trainFraction, + shuffle, + cfg, + engine=engine, ) auxiliaryfunctions.attempt_to_make_folder( Path(config).parents[0] / modelfoldername, recursive=True @@ -396,88 +492,126 @@ def create_multianimaltraining_dataset( ) ) - jointnames = [str(bpt) for bpt in multianimalbodyparts] - jointnames.extend([str(bpt) for bpt in uniquebodyparts]) - items2change = { - "dataset": datafilename, - "metadataset": metadatafilename, - "num_joints": len(multianimalbodyparts) - + len(uniquebodyparts), # cfg["uniquebodyparts"]), - "all_joints": [ - [i] for i in range(len(multianimalbodyparts) + len(uniquebodyparts)) - ], # cfg["uniquebodyparts"]))], - "all_joints_names": jointnames, - "init_weights": model_path, - "project_path": str(cfg["project_path"]), - "net_type": net_type, - "multi_stage": multi_stage, - "pairwise_loss_weight": 0.1, - "pafwidth": 20, - "partaffinityfield_graph": partaffinityfield_graph, - "partaffinityfield_predict": partaffinityfield_predict, - "weigh_only_present_joints": False, - "num_limbs": len(partaffinityfield_graph), - "dataset_type": dataset_type, - "optimizer": "adam", - "batch_size": 8, - "multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]], - "save_iters": 10000, - "display_iters": 500, - "num_idchannel": ( - len(cfg["individuals"]) if cfg.get("identity", False) else 0 - ), - "crop_size": list(crop_size), - "crop_sampling": crop_sampling, - } + if engine == Engine.TF: + jointnames = [str(bpt) for bpt in multianimalbodyparts] + jointnames.extend([str(bpt) for bpt in uniquebodyparts]) + items2change = { + "dataset": datafilename, + "engine": engine.aliases[0], + "metadataset": metadatafilename, + "num_joints": len(multianimalbodyparts) + + len(uniquebodyparts), # cfg["uniquebodyparts"]), + "all_joints": [ + [i] + for i in range(len(multianimalbodyparts) + len(uniquebodyparts)) + ], # cfg["uniquebodyparts"]))], + "all_joints_names": jointnames, + "init_weights": str(model_path), + "project_path": str(cfg["project_path"]), + "net_type": net_type, + "multi_stage": multi_stage, + "pairwise_loss_weight": 0.1, + "pafwidth": 20, + "partaffinityfield_graph": partaffinityfield_graph, + "partaffinityfield_predict": partaffinityfield_predict, + "weigh_only_present_joints": False, + "num_limbs": len(partaffinityfield_graph), + "dataset_type": dataset_type, + "optimizer": "adam", + "batch_size": 8, + "multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]], + "save_iters": 10000, + "display_iters": 500, + "num_idchannel": ( + len(cfg["individuals"]) if cfg.get("identity", False) else 0 + ), + "crop_size": list(crop_size), + "crop_sampling": crop_sampling, + } + + trainingdata = MakeTrain_pose_yaml( + items2change, + path_train_config, + defaultconfigfile, + save=(engine == Engine.TF), + ) + keys2save = [ + "dataset", + "num_joints", + "all_joints", + "all_joints_names", + "net_type", + "multi_stage", + "init_weights", + "global_scale", + "location_refinement", + "locref_stdev", + "dataset_type", + "partaffinityfield_predict", + "pairwise_predict", + "partaffinityfield_graph", + "num_limbs", + "dataset_type", + "num_idchannel", + ] + + MakeTest_pose_yaml( + trainingdata, + keys2save, + path_test_config, + nmsradius=5.0, + minconfidence=0.01, + sigma=1, + locref_smooth=False, + ) # setting important def. values for inference + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.config.make_pose_config import ( + make_pytorch_pose_config, + make_pytorch_test_config, + ) + from deeplabcut.pose_estimation_pytorch.modelzoo.config import ( + make_super_animal_finetune_config, + ) - trainingdata = MakeTrain_pose_yaml( - items2change, path_train_config, defaultconfigfile - ) - keys2save = [ - "dataset", - "num_joints", - "all_joints", - "all_joints_names", - "net_type", - "multi_stage", - "init_weights", - "global_scale", - "location_refinement", - "locref_stdev", - "dataset_type", - "partaffinityfield_predict", - "pairwise_predict", - "partaffinityfield_graph", - "num_limbs", - "dataset_type", - "num_idchannel", - ] + # backwards compatibility with version 2.X + if net_type == "dlcrnet_ms5": + net_type = "dlcrnet_stride16_ms5" + + config_path = Path(path_train_config).with_name(engine.pose_cfg_name) + if weight_init is not None and weight_init.with_decoder: + pytorch_cfg = make_super_animal_finetune_config( + project_config=cfg, + pose_config_path=config_path, + model_name=net_type, + detector_name=detector_type, + weight_init=weight_init, + save=True, + ) + else: + pytorch_cfg = make_pytorch_pose_config( + project_config=cfg, + pose_config_path=config_path, + net_type=net_type, + top_down=top_down, + detector_type=detector_type, + weight_init=weight_init, + save=True, + ) - MakeTest_pose_yaml( - trainingdata, - keys2save, - path_test_config, - nmsradius=5.0, - minconfidence=0.01, - sigma=1, - locref_smooth=False, - ) # setting important def. values for inference + make_pytorch_test_config(pytorch_cfg, path_test_config, save=True) # Setting inference cfg file: - defaultinference_configfile = os.path.join( - dlcparent_path, "inference_cfg.yaml" - ) - items2change = { - "minimalnumberofconnections": int(len(cfg["multianimalbodyparts"]) / 2), - "topktoretain": len(cfg["individuals"]), - "withid": cfg.get("identity", False), - } - MakeInference_yaml( - items2change, path_inference_config, defaultinference_configfile + default_inf_path = Path(dlcparent_path) / "inference_cfg.yaml" + inf_updates = dict( + minimalnumberofconnections=int(len(cfg["multianimalbodyparts"]) / 2), + topktoretain=len(cfg["individuals"]), + withid=cfg.get("identity", False), ) + MakeInference_yaml(inf_updates, path_inference_config, default_inf_path) print( - "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!" + "The training dataset is successfully created. Use the function " + "'train_network' to start training. Happy training!" ) else: pass diff --git a/deeplabcut/generate_training_dataset/trainingsetmanipulation.py b/deeplabcut/generate_training_dataset/trainingsetmanipulation.py index 244ab2fcb5..08f7694633 100755 --- a/deeplabcut/generate_training_dataset/trainingsetmanipulation.py +++ b/deeplabcut/generate_training_dataset/trainingsetmanipulation.py @@ -8,6 +8,7 @@ # # Licensed under GNU Lesser General Public License v3.0 # +from __future__ import annotations import math import logging @@ -24,7 +25,10 @@ import pandas as pd import yaml -from deeplabcut.pose_estimation_tensorflow import training +import deeplabcut.compat as compat +import deeplabcut.generate_training_dataset.metadata as metadata +from deeplabcut.core.engine import Engine +from deeplabcut.core.weight_init import WeightInitialization from deeplabcut.utils import ( auxiliaryfunctions, conversioncode, @@ -32,8 +36,6 @@ auxfun_multianimal, ) from deeplabcut.utils.auxfun_videos import VideoReader -from deeplabcut.pose_estimation_tensorflow.config import load_config -from deeplabcut.modelzoo.utils import parse_available_supermodels def comparevideolistsanddatafolders(config): @@ -397,19 +399,26 @@ def ParseYaml(configfile): def MakeTrain_pose_yaml( - itemstochange, saveasconfigfile, defaultconfigfile, items2drop={} + itemstochange, + saveasconfigfile, + defaultconfigfile, + items2drop: dict | None = None, + save: bool = True, ): + if items2drop is None: + items2drop = {} + docs = ParseYaml(defaultconfigfile) for key in items2drop.keys(): - # print(key, "dropping?") if key in docs[0].keys(): docs[0].pop(key) for key in itemstochange.keys(): docs[0][key] = itemstochange[key] - with open(saveasconfigfile, "w") as f: - yaml.dump(docs[0], f) + if save: + with open(saveasconfigfile, "w") as f: + yaml.dump(docs[0], f) return docs[0] @@ -772,13 +781,16 @@ def create_training_dataset( num_shuffles=1, Shuffles=None, windows2linux=False, - userfeedback=False, + userfeedback=True, trainIndices=None, testIndices=None, net_type=None, + detector_type=None, augmenter_type=None, posecfg_template=None, superanimal_name="", + weight_init: WeightInitialization | None = None, + engine: Engine | None = None, ): """Creates a training dataset. @@ -797,7 +809,7 @@ def create_training_dataset( Shuffles: list[int], optional Alternatively the user can also give a list of shuffles. - userfeedback: bool, optional, default=False + userfeedback: bool, optional, default=True If ``False``, all requested train/test splits are created (no matter if they already exist). If you want to assure that previous splits etc. are not overwritten, set this to ``True`` and you will be asked for each split. @@ -810,41 +822,83 @@ def create_training_dataset( List of one or multiple lists containing test indexes. net_type: list, optional, default=None - Type of networks. Currently supported options are - - * ``resnet_50`` - * ``resnet_101`` - * ``resnet_152`` - * ``mobilenet_v2_1.0`` - * ``mobilenet_v2_0.75`` - * ``mobilenet_v2_0.5`` - * ``mobilenet_v2_0.35`` - * ``efficientnet-b0`` - * ``efficientnet-b1`` - * ``efficientnet-b2`` - * ``efficientnet-b3`` - * ``efficientnet-b4`` - * ``efficientnet-b5`` - * ``efficientnet-b6`` + Type of networks. The options available depend on which engine is used. + Currently supported options are: + TensorFlow + * ``resnet_50`` + * ``resnet_101`` + * ``resnet_152`` + * ``mobilenet_v2_1.0`` + * ``mobilenet_v2_0.75`` + * ``mobilenet_v2_0.5`` + * ``mobilenet_v2_0.35`` + * ``efficientnet-b0`` + * ``efficientnet-b1`` + * ``efficientnet-b2`` + * ``efficientnet-b3`` + * ``efficientnet-b4`` + * ``efficientnet-b5`` + * ``efficientnet-b6`` + PyTorch (call ``deeplabcut.pose_estimation_pytorch.available_models()`` for + a complete list) + * ``resnet_50`` + * ``resnet_101`` + * ``hrnet_w18`` + * ``hrnet_w32`` + * ``hrnet_w48`` + * ``dekr_w18`` + * ``dekr_w32`` + * ``dekr_w48`` + * ``top_down_resnet_50`` + * ``top_down_resnet_101`` + * ``top_down_hrnet_w18`` + * ``top_down_hrnet_w32`` + * ``top_down_hrnet_w48`` + * ``animaltokenpose_base`` + + detector_type: string, optional, default=None + Only for the PyTorch engine. + When passing creating shuffles for top-down models, you can specify which + detector you want. If the detector_type is None, the ```ssdlite``` will be used. + The list of all available detectors can be obtained by calling + ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported options: + * ``ssdlite`` + * ``fasterrcnn_mobilenet_v3_large_fpn`` + * ``fasterrcnn_resnet50_fpn_v2`` augmenter_type: string, optional, default=None - Type of augmenter. Currently supported augmenters are - - * ``default`` - * ``scalecrop`` - * ``imgaug`` - * ``tensorpack`` - * ``deterministic`` + Type of augmenter. The options available depend on which engine is used. + Currently supported options are: + TensorFlow + * ``default`` + * ``scalecrop`` + * ``imgaug`` + * ``tensorpack`` + * ``deterministic`` + PyTorch + * ``albumentations`` posecfg_template: string, optional, default=None + Only for the TensorFlow engine. Path to a ``pose_cfg.yaml`` file to use as a template for generating the new one for the current iteration. Useful if you would like to start with the same parameters a previous training iteration. None uses the default ``pose_cfg.yaml``. superanimal_name: string, optional, default="" - Specify the superanimal name is transfer learning with superanimal is desired. This makes sure the pose config template uses superanimal configs as template + Only for the TensorFlow engine. For the PyTorch engine, use the ``weight_init`` + parameter. + Specify the superanimal name is transfer learning with superanimal is desired. + This makes sure the pose config template uses superanimal configs as template. + weight_init: WeightInitialisation, optional, default=None + PyTorch engine only. Specify how model weights should be initialized. The + default mode uses transfer learning from ImageNet weights. + + engine: Engine, optional + Whether to create a pose config for a Tensorflow or PyTorch model. Defaults to + the value specified in the project configuration file. If no engine is specified + for the project, defaults to ``deeplabcut.compat.DEFAULT_ENGINE``. Returns ------- @@ -890,6 +944,7 @@ def create_training_dataset( dlc_root_path = auxiliaryfunctions.get_deeplabcut_path() if superanimal_name != "": + # FIXME(niels): this is deprecated supermodels = parse_available_supermodels() posecfg_template = os.path.join( dlc_root_path, @@ -922,12 +977,19 @@ def create_training_dataset( num_shuffles, Shuffles, net_type=net_type, + detector_type=detector_type, trainIndices=trainIndices, testIndices=testIndices, + userfeedback=userfeedback, + engine=engine, + weight_init=weight_init, ) else: scorer = cfg["scorer"] project_path = cfg["project_path"] + if engine is None: + engine = compat.get_project_engine(cfg) + # Create path for training sets & store data there trainingsetfolder = auxiliaryfunctions.get_training_set_folder( cfg @@ -936,6 +998,11 @@ def create_training_dataset( Path(os.path.join(project_path, str(trainingsetfolder))), recursive=True ) + # Create the trainset metadata file, if it doesn't yet exist + if not metadata.TrainingDatasetMetadata.path(cfg).exists(): + trainset_metadata = metadata.TrainingDatasetMetadata.create(cfg) + trainset_metadata.save() + Data = merge_annotateddatasets( cfg, Path(os.path.join(project_path, trainingsetfolder)), @@ -947,6 +1014,8 @@ def create_training_dataset( # loading & linking pretrained models if net_type is None: # loading & linking pretrained models net_type = cfg.get("default_net_type", "resnet_50") + elif engine == Engine.PYTORCH: + pass else: if ( "resnet" in net_type @@ -958,20 +1027,40 @@ def create_training_dataset( else: raise ValueError("Invalid network type:", net_type) + top_down = False + if engine == Engine.PYTORCH: + if net_type.startswith("top_down_"): + top_down = True + net_type = net_type[len("top_down_") :] + + augmenters = compat.get_available_aug_methods(engine) + default_augmenter = augmenters[0] if augmenter_type is None: - augmenter_type = cfg.get("default_augmenter", "imgaug") + augmenter_type = cfg.get("default_augmenter", default_augmenter) + if augmenter_type is None: # this could be in config.yaml for old projects! # updating variable if null/None! #backwardscompatability - auxiliaryfunctions.edit_config(config, {"default_augmenter": "imgaug"}) - augmenter_type = "imgaug" - elif augmenter_type not in [ - "default", - "scalecrop", - "imgaug", - "tensorpack", - "deterministic", - ]: - raise ValueError("Invalid augmenter type:", augmenter_type) + augmenter_type = default_augmenter + auxiliaryfunctions.edit_config( + config, {"default_augmenter": augmenter_type} + ) + elif augmenter_type not in augmenters: + # as the default augmenter might not be available for the given engine + augmenter_type = default_augmenter + logging.info( + f"Default augmenter {augmenter_type} not available for engine " + f"{engine}: using {default_augmenter} instead" + ) + + if augmenter_type not in augmenters: + if engine != Engine.PYTORCH: + raise ValueError( + f"Invalid augmenter type: {augmenter_type} (available: for " + f"engine={engine}: {augmenters})" + ) + + logging.info(f"Switching augmentation to {default_augmenter} for PyTorch") + augmenter_type = default_augmenter if posecfg_template: if net_type != prior_cfg["net_type"]: @@ -989,12 +1078,13 @@ def create_training_dataset( defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml") elif posecfg_template: defaultconfigfile = posecfg_template - model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path)) - if Shuffles is None: - Shuffles = range(1, num_shuffles + 1) + if engine == Engine.PYTORCH: + model_path = dlcparent_path else: - Shuffles = [i for i in Shuffles if isinstance(i, int)] + model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path)) + + Shuffles = validate_shuffles(cfg, Shuffles, num_shuffles, userfeedback) # print(trainIndices,testIndices, Shuffles, augmenter_type,net_type) if trainIndices is None and testIndices is None: @@ -1032,15 +1122,16 @@ def create_training_dataset( (trainFraction, Shuffles[shuffle], (train_inds, test_inds)) ) - bodyparts = cfg["bodyparts"] + bodyparts = auxiliaryfunctions.get_bodyparts(cfg) nbodyparts = len(bodyparts) for trainFraction, shuffle, (trainIndices, testIndices) in splits: if len(trainIndices) > 0: if userfeedback: - trainposeconfigfile, _, _ = training.return_train_network_path( + trainposeconfigfile, _, _ = compat.return_train_network_path( config, shuffle=shuffle, trainingsetindex=cfg["TrainingFraction"].index(trainFraction), + engine=engine, ) if trainposeconfigfile.is_file(): askuser = input( @@ -1087,13 +1178,25 @@ def create_training_dataset( testIndices, trainFraction, ) + metadata.update_metadata( + cfg=cfg, + train_fraction=trainFraction, + shuffle=shuffle, + engine=engine, + train_indices=trainIndices, + test_indices=testIndices, + overwrite=not userfeedback, + ) ################################################################################ # Creating file structure for training & # Test files as well as pose_yaml files (containing training and testing information) ################################################################################# modelfoldername = auxiliaryfunctions.get_model_folder( - trainFraction, shuffle, cfg + trainFraction, + shuffle, + cfg, + engine=engine, ) auxiliaryfunctions.attempt_to_make_folder( Path(config).parents[0] / modelfoldername, recursive=True @@ -1110,7 +1213,7 @@ def create_training_dataset( cfg["project_path"], Path(modelfoldername), "train", - "pose_cfg.yaml", + engine.pose_cfg_name, ) ) path_test_config = str( @@ -1121,68 +1224,204 @@ def create_training_dataset( "pose_cfg.yaml", ) ) - # str(cfg['proj_path']+'/'+Path(modelfoldername) / 'test' / 'pose_cfg.yaml') - items2change = { - "dataset": datafilename, - "metadataset": metadatafilename, - "num_joints": len(bodyparts), - "all_joints": [[i] for i in range(len(bodyparts))], - "all_joints_names": [str(bpt) for bpt in bodyparts], - "init_weights": model_path, - "project_path": str(cfg["project_path"]), - "net_type": net_type, - "dataset_type": augmenter_type, - } - - items2drop = {} - if augmenter_type == "scalecrop": - # these values are dropped as scalecrop - # doesn't have rotation implemented - items2drop = {"rotation": 0, "rotratio": 0.0} - # Also drop maDLC smart cropping augmentation parameters - for key in ["pre_resize", "crop_size", "max_shift", "crop_sampling"]: - items2drop[key] = None - - trainingdata = MakeTrain_pose_yaml( - items2change, path_train_config, defaultconfigfile, items2drop - ) + if engine == Engine.TF: + items2change = { + "dataset": datafilename, + "engine": engine.aliases[0], + "metadataset": metadatafilename, + "num_joints": len(bodyparts), + "all_joints": [[i] for i in range(len(bodyparts))], + "all_joints_names": [str(bpt) for bpt in bodyparts], + "init_weights": model_path, + "project_path": str(cfg["project_path"]), + "net_type": net_type, + "dataset_type": augmenter_type, + } + + items2drop = {} + if augmenter_type == "scalecrop": + # these values are dropped as scalecrop + # doesn't have rotation implemented + items2drop = {"rotation": 0, "rotratio": 0.0} + # Also drop maDLC smart cropping augmentation parameters + for key in [ + "pre_resize", + "crop_size", + "max_shift", + "crop_sampling", + ]: + items2drop[key] = None + + trainingdata = MakeTrain_pose_yaml( + items2change, + path_train_config, + defaultconfigfile, + items2drop, + save=(engine == Engine.TF), + ) - keys2save = [ - "dataset", - "num_joints", - "all_joints", - "all_joints_names", - "net_type", - "init_weights", - "global_scale", - "location_refinement", - "locref_stdev", - ] - MakeTest_pose_yaml(trainingdata, keys2save, path_test_config) - print( - "The training dataset is successfully created. Use the function 'train_network' to start training. Happy training!" - ) + keys2save = [ + "dataset", + "num_joints", + "all_joints", + "all_joints_names", + "net_type", + "init_weights", + "global_scale", + "location_refinement", + "locref_stdev", + ] + MakeTest_pose_yaml(trainingdata, keys2save, path_test_config) + print( + "The training dataset is successfully created. Use the function" + "'train_network' to start training. Happy training!" + ) + elif engine == Engine.PYTORCH: + from deeplabcut.pose_estimation_pytorch.config.make_pose_config import ( + make_pytorch_pose_config, + make_pytorch_test_config, + ) + from deeplabcut.pose_estimation_pytorch.modelzoo.config import ( + make_super_animal_finetune_config, + ) + + if weight_init is not None and weight_init.with_decoder: + pytorch_cfg = make_super_animal_finetune_config( + project_config=cfg, + pose_config_path=path_train_config, + model_name=net_type, + detector_name=detector_type, + weight_init=weight_init, + save=True, + ) + else: + pytorch_cfg = make_pytorch_pose_config( + project_config=cfg, + pose_config_path=path_train_config, + net_type=net_type, + top_down=top_down, + detector_type=detector_type, + weight_init=weight_init, + save=True, + ) + + make_pytorch_test_config(pytorch_cfg, path_test_config, save=True) return splits def get_largestshuffle_index(config): """Returns the largest shuffle for all dlc-models in the current iteration.""" - cfg = auxiliaryfunctions.read_config(config) - project_path = cfg["project_path"] - iterate = "iteration-" + str(cfg["iteration"]) - dlc_model_path = os.path.join(project_path, "dlc-models", iterate) - if os.path.isdir(dlc_model_path): - models = os.listdir(dlc_model_path) - # sort the model directories - models.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) - - # get the shuffle index and offset by 1. - max_shuffle_index = int(models[-1].split("shuffle")[-1]) + 1 + shuffle_indices = get_existing_shuffle_indices(config) + if len(shuffle_indices) > 0: + return shuffle_indices[-1] + + return None + + +def get_existing_shuffle_indices( + cfg: dict | str | Path, + train_fraction: float | None = None, + engine: Engine | None = None, +) -> List[int]: + """ + Args: + cfg: The content of a project configuration file, or the path to the project + configuration file. + train_fraction: If defined, only get the indices of shuffles with this train + fraction. + engine: If specified, returns only the shuffle indices that were created with + the given engine. Can only be used when train_fraction is also defined. + + Returns: + the indices of existing shuffles for this iteration of the project, sorted by + ascending index + """ + + def is_valid_data_stem(stem: str) -> bool: + if len(stem) == 0: + return False + suffix = stem.split("_")[-1] + if len(suffix) == 0: + return False + info = suffix.split("shuffle") + if len(info) != 2: + return False + train_frac, idx = info + return ( + train_frac.isdigit() + and idx.isdigit() + and (train_fraction is None or int(train_frac) == int(100 * train_fraction)) + ) + + if isinstance(cfg, (str, Path)): + cfg = auxiliaryfunctions.read_config(cfg) + + project = Path(cfg["project_path"]) + trainset_folder = project / auxiliaryfunctions.get_training_set_folder(cfg) + if not trainset_folder.exists(): + return [] + + shuffle_indices = [ + int(p.stem.split("shuffle")[-1]) + for p in trainset_folder.iterdir() + if ( + p.stem.startswith("Documentation_data") + and p.suffix == ".pickle" + and is_valid_data_stem(p.stem) + ) + ] + if engine is not None: + if train_fraction is None: + raise ValueError( + f"Must select {train_fraction} to filter shuffles by engine" + ) + + shuffle_indices = [ + idx + for idx in shuffle_indices + if ( + project + / auxiliaryfunctions.get_model_folder( + trainFraction=train_fraction, + shuffle=idx, + cfg=cfg, + engine=engine, + ) + ).exists() + ] + + return sorted(shuffle_indices) + + +def validate_shuffles( + cfg: dict, + shuffles: list[int] | None, + num_shuffles: int | None, + userfeedback: bool, +) -> list[int]: + existing_shuffles = get_existing_shuffle_indices(cfg) + if shuffles is None: + first_index = 1 + if len(existing_shuffles) > 0: + first_index = existing_shuffles[-1] + 1 + + shuffles = range(first_index, num_shuffles + first_index) else: - max_shuffle_index = 0 + shuffles = [i for i in shuffles if isinstance(i, int)] + for shuffle_idx in shuffles: + if userfeedback and shuffle_idx in existing_shuffles: + raise ValueError( + f"Cannot create shuffle {shuffle_idx} as it already exists - " + f"you must either create the dataset with `userfeedback=False` " + f"or delete the shuffle with index {shuffle_idx} manually (in " + f"`dlc-models`/`dlc-models-pytorch` and in the " + f"`training-datasets` folder) if you want to create a new " + f"shuffle with that index. You can otherwise create a shuffle " + f"with a new index. Existing indices are {existing_shuffles}." + ) - return max_shuffle_index + return shuffles def create_training_model_comparison( @@ -1301,7 +1540,7 @@ def create_training_model_comparison( else: pass - largestshuffleindex = get_largestshuffle_index(config) + largestshuffleindex = get_existing_shuffle_indices(cfg)[-1] + 1 shuffle_list = [] for shuffle in range(num_shuffles): @@ -1342,3 +1581,183 @@ def create_training_model_comparison( logger.info(log_info) return shuffle_list + + +def create_training_dataset_from_existing_split( + config: str, + from_shuffle: int, + from_trainsetindex: int = 0, + num_shuffles: int = 1, + shuffles: list[int] | None = None, + userfeedback: bool = True, + net_type: str | None = None, + detector_type: str | None = None, + augmenter_type: str | None = None, + posecfg_template: dict | None = None, + superanimal_name: str = "", + weight_init: WeightInitialization | None = None, + engine: Engine | None = None, +) -> None | list[int]: + """ + Labels from all the extracted frames are merged into a single .h5 file. + Only the videos included in the config file are used to create this dataset. + + Args: + config: Full path of the ``config.yaml`` file as a string. + + from_shuffle: The index of the shuffle from which to copy the train/test split. + + from_trainsetindex: The trainset index of the shuffle from which to use the data + split. Default is 0. + + num_shuffles: Number of shuffles of training dataset to create, used if + ``shuffles`` is None. + + shuffles: If defined, ``num_shuffles`` is ignored and a shuffle is created for + each index given in the list. + + userfeedback: If ``False``, all requested train/test splits are created (no + matter if they already exist). If you want to assure that previous splits + etc. are not overwritten, set this to ``True`` and you will be asked for + each existing split if you want to overwrite it. + + net_type: The type of network to create the shuffle for. Currently supported + options for engine=Engine.TF are: + * ``resnet_50`` + * ``resnet_101`` + * ``resnet_152`` + * ``mobilenet_v2_1.0`` + * ``mobilenet_v2_0.75`` + * ``mobilenet_v2_0.5`` + * ``mobilenet_v2_0.35`` + * ``efficientnet-b0`` + * ``efficientnet-b1`` + * ``efficientnet-b2`` + * ``efficientnet-b3`` + * ``efficientnet-b4`` + * ``efficientnet-b5`` + * ``efficientnet-b6`` + Currently supported options for engine=Engine.TF can be obtained by calling + ``deeplabcut.pose_estimation_pytorch.available_models()``. + + detector_type: string, optional, default=None + Only for the PyTorch engine. + When passing creating shuffles for top-down models, you can specify which + detector you want. If the detector_type is None, the ```ssdlite``` will be + used. The list of all available detectors can be obtained by calling + ``deeplabcut.pose_estimation_pytorch.available_detectors()``. Supported + options: + * ``ssdlite`` + * ``fasterrcnn_mobilenet_v3_large_fpn`` + * ``fasterrcnn_resnet50_fpn_v2`` + + augmenter_type: Type of augmenter. Currently supported augmenters for + engine=Engine.TF are + * ``default`` + * ``scalecrop`` + * ``imgaug`` + * ``tensorpack`` + * ``deterministic`` + The only supported augmenter for Engine.PYTORCH is ``albumentations``. + + posecfg_template: Only for Engine.TF. Path to a ``pose_cfg.yaml`` file to use as + a template for generating the new one for the current iteration. Useful if + you would like to start with the same parameters a previous training + iteration. None uses the default ``pose_cfg.yaml``. + + superanimal_name: Specify the superanimal name is transfer learning with + superanimal is desired. This makes sure the pose config template uses + superanimal configs as template. + + weight_init: Only for Engine.PYTORCH. Specify how model weights should be + initialized. The default mode uses transfer learning from ImageNet weights. + + engine: Whether to create a pose config for a Tensorflow or PyTorch model. + Defaults to the value specified in the project configuration file. If no + engine is specified for the project, defaults to + ``deeplabcut.compat.DEFAULT_ENGINE``. + + Returns: + If training dataset was successfully created, a list of tuples is returned. + The first two elements in each tuple represent the training fraction and the + shuffle value. The last two elements in each tuple are arrays of integers + representing the training and test indices. + + Returns None if training dataset could not be created. + + Raises: + ValueError: If the shuffle from which to copy the data split doesn't exist. + """ + cfg = auxiliaryfunctions.read_config(config) + trainset_meta_path = metadata.TrainingDatasetMetadata.path(cfg) + if not trainset_meta_path.exists(): + meta = metadata.TrainingDatasetMetadata.create(cfg) + meta.save() + else: + meta = metadata.TrainingDatasetMetadata.load(cfg, load_splits=False) + + shuffle = meta.get(trainset_index=from_trainsetindex, index=from_shuffle) + shuffle = shuffle.load_split(cfg, trainset_path=trainset_meta_path.parent) + + num_copies = num_shuffles + if shuffles is not None: + num_copies = len(shuffles) + + # pad the train and test indices with -1s so the training fraction is exact + train_idx = list(shuffle.split.train_indices) + test_idx = list(shuffle.split.test_indices) + n_train, n_test = len(train_idx), len(test_idx) + + train_fraction = round(cfg["TrainingFraction"][from_trainsetindex], 2) + if round(n_train / (n_train + n_test), 2) != train_fraction: + train_padding, test_padding = _compute_padding(train_fraction, n_train, n_test) + train_idx = train_idx + (train_padding * [-1]) + test_idx = test_idx + (test_padding * [-1]) + + return create_training_dataset( + config=config, + num_shuffles=num_shuffles, + Shuffles=shuffles, + userfeedback=userfeedback, + trainIndices=[train_idx for _ in range(num_copies)], + testIndices=[test_idx for _ in range(num_copies)], + net_type=net_type, + detector_type=detector_type, + augmenter_type=augmenter_type, + posecfg_template=posecfg_template, + superanimal_name=superanimal_name, + weight_init=weight_init, + engine=engine, + ) + + +def _compute_padding( + train_fraction: float, + num_train: int, + num_test: int, +) -> tuple[int, int]: + """ + Computes the amount of padding to add to train/test indices such that + train_fraction = num_train / (num_train + num_test). + + Returns: + the number of padding indices to add to the train indices + the number of padding indices to add to the test indices + """ + if train_fraction <= 0 or train_fraction >= 1: + raise ValueError( + f"The training fraction must satisfy 0 < TrainingFraction < 1, but " + f"{train_fraction} was found" + ) + + base_images = 100 + train_step = int(round(round(train_fraction, 2) * base_images)) + test_step = base_images - train_step + + tgt_train = train_step + tgt_test = test_step + while tgt_train < num_train or tgt_test < num_test: + tgt_train += train_step + tgt_test += test_step + + return (tgt_train - num_train), (tgt_test - num_test) diff --git a/deeplabcut/gui/components.py b/deeplabcut/gui/components.py index 6fc9cfa7f5..5edc7712a4 100644 --- a/deeplabcut/gui/components.py +++ b/deeplabcut/gui/components.py @@ -89,6 +89,11 @@ def __init__( self.itemSelectionChanged.connect(self.update_selected_bodyparts) + def refresh(self): + self.clear() + self.addItems(self.root.all_bodyparts) + self.update_selected_bodyparts() + def update_selected_bodyparts(self): self.selected_bodyparts = [item.text() for item in self.selectedItems()] self.root.logger.info(f"Selected bodyparts:\n\t{self.selected_bodyparts}") @@ -179,6 +184,81 @@ def clear_selected_videos(self): self.root.logger.info(f"Cleared selected videos") +class SnapshotSelectionWidget(QtWidgets.QWidget): + def __init__( + self, + root: QtWidgets.QMainWindow, + parent: QtWidgets.QWidget, + margins: tuple, + select_button_text: str, + ): + super(SnapshotSelectionWidget, self).__init__(parent) + + self.root = root + self.parent = parent + + self.selected_snapshot = None + + self._init_layout(margins, select_button_text) + + def _init_layout(self, margins, select_button_text): + layout = _create_horizontal_layout(margins=margins) + + # Select videos + self.select_snapshot_button = QtWidgets.QPushButton(select_button_text) + self.select_snapshot_button.setMaximumWidth(200) + self.select_snapshot_button.clicked.connect(self.select_snapshot) + + # Selected snapshot text + self.selected_snapshot_text = QtWidgets.QLabel( + "" + ) # updated when snapshot is selected + + # Clear snapshot selection + self.clear_snapshot_button = QtWidgets.QPushButton("Clear selection") + self.clear_snapshot_button.clicked.connect(self.clear_selected_snapshot) + self.clear_snapshot_button.hide() + + layout.addWidget(self.select_snapshot_button) + layout.addWidget(self.selected_snapshot_text) + layout.addWidget(self.clear_snapshot_button, alignment=Qt.AlignRight) + + self.setLayout(layout) + + def _update_selected_snapshot_display(self): + if self.selected_snapshot is None: + self.selected_snapshot_text.setText("") + self.clear_snapshot_button.hide() + else: + self.selected_snapshot_text.setText( + f"{os.path.basename(self.selected_snapshot)}" + ) + self.clear_snapshot_button.show() + + def select_snapshot(self): + # Create a filter string with both lowercase and uppercase extensions + snapshot_types = ["*.pt", "*.PT"] + snapshot_files = f"Snapshots ({' '.join(snapshot_types)})" + + directory_to_open = self.root.models_folder + + selected_snapshot, _ = QtWidgets.QFileDialog.getOpenFileName( + self, + "Select snapshot to start training from", + directory_to_open, + snapshot_files, + ) + # When Canceling a file selection, Qt returns an empty string as selected file + if selected_snapshot: + self.selected_snapshot = os.path.abspath(selected_snapshot) + + self._update_selected_snapshot_display() + + def clear_selected_snapshot(self): + self.selected_snapshot = None + self._update_selected_snapshot_display() + + class TrainingSetSpinBox(QtWidgets.QSpinBox): def __init__(self, root, parent): super(TrainingSetSpinBox, self).__init__(parent) @@ -201,6 +281,12 @@ def __init__(self, root, parent): self.setMaximum(10_000) self.setValue(self.root.shuffle_value) self.valueChanged.connect(self.root.update_shuffle) + self.root.shuffle_change.connect(self.update_shuffle) + + @Slot(int) + def update_shuffle(self, new_shuffle: int): + if new_shuffle != self.value(): + self.setValue(new_shuffle) class DefaultTab(QtWidgets.QWidget): diff --git a/deeplabcut/gui/displays/__init__.py b/deeplabcut/gui/displays/__init__.py new file mode 100644 index 0000000000..f511e6184a --- /dev/null +++ b/deeplabcut/gui/displays/__init__.py @@ -0,0 +1,12 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# + + diff --git a/deeplabcut/gui/displays/selected_shuffle_display.py b/deeplabcut/gui/displays/selected_shuffle_display.py new file mode 100644 index 0000000000..13e6db1094 --- /dev/null +++ b/deeplabcut/gui/displays/selected_shuffle_display.py @@ -0,0 +1,130 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Module to display information about the selected shuffle in the GUI""" +from __future__ import annotations +from pathlib import Path + +import PySide6.QtCore as QtCore +from PySide6 import QtWidgets + +from deeplabcut.core.engine import Engine +from deeplabcut.utils import auxiliaryfunctions + + +class SelectedShuffleDisplay(QtWidgets.QWidget): + """A widget displaying information about the selected shuffle""" + pose_cfg_signal = QtCore.Signal(dict) + + def __init__(self, root, row_margin: int = 25): + super().__init__() + self.root = root + + self._row_margin = row_margin + + self._current_index: int | None = None + self._engine: Engine | None = None + self._is_top_down: bool = False + self._net_type: str | None = None + self._pose_cfg: dict | None = None + + self._label = QtWidgets.QLabel("Shuffle info:") + self._label.setStyleSheet(f"margin: 0px 0px {self._row_margin}px 0px") + layout = QtWidgets.QHBoxLayout() + layout.addWidget(self._label) + self.setLayout(layout) + + # initialize the display + self._update_display(self.root.shuffle_value) + + # update the display when the shuffle or selected engine changes, or when a new + # shuffle has been created + self.root.shuffle_change.connect(self._update_display) + self.root.engine_change.connect(self._update_display) + self.root.shuffle_created.connect(self._update_display) + + @property + def pose_cfg(self) -> dict | None: + return self._pose_cfg + + @pose_cfg.setter + def pose_cfg(self, value: dict | None) -> None: + self._pose_cfg = value + self.pose_cfg_signal.emit(self._pose_cfg) + + @QtCore.Slot(int) + def _update_display(self, new_index: int) -> None: + self._current_index = new_index + + try: + pose_cfg_path = Path(self.root.pose_cfg_path) + except ValueError as err: + self._set_text_error( + f"Failed to read shuffle {self._current_index} - check that it exists!" + ) + return + except ModuleNotFoundError as err: + # Loading a TF shuffle but TF is not installed + self._set_text_error( + f"Failed to read shuffle {self._current_index} due to error `{err}`.\n" + "If the error is `ModuleNotFoundError: No module named 'tensorflow'`, " + f"this is because\nshuffle {self._current_index} uses the tensorflow " + " engine, but TensorFlow is not installed in your environment.\n" + "Ignore this error if you'll just train PyTorch models. To train " + "TensorFlow models, install it with \n" + " Windows/Linux: pip install 'deeplabcut[tf]'\n" + " Apple Silicon: pip install 'deeplabcut[apple_mchips]'" + ) + return + + if not pose_cfg_path.exists(): + self._set_text_error( + f"The model configuration file {pose_cfg_path} was not created" + ) + return + + self._read_pose_config(pose_cfg_path) + self._set_text() + + def _set_text(self) -> None: + engine_str = "None" + if self._engine is not None: + engine_str = self._engine.aliases[0] + + text = f"net type: {self._net_type} | engine: {engine_str}" + if self._engine == Engine.PYTORCH and self._is_top_down: + text += f" | top-down" + + style = f"margin: 0px 0px {self._row_margin}px 0px;" + if self._engine != self.root.engine: + warning = "Change the selected Engine in the top-right to use this shuffle!" + text = warning + " | " + text + style += " color: orange;" + + self._label.setStyleSheet(style) + self._label.setText(text) + + def _set_text_error(self, error: str) -> None: + self._label.setText(error) + style = f"margin: 0px 0px {self._row_margin}px 0px; color: orange;" + self._label.setStyleSheet(style) + self.pose_cfg = None + + def _read_pose_config(self, pose_cfg_path: Path) -> None: + pose_cfg = auxiliaryfunctions.read_plainconfig(str(pose_cfg_path)) + + self._engine = ( + Engine.PYTORCH if "pytorch" in pose_cfg_path.stem.lower() else Engine.TF + ) + self._net_type = pose_cfg.get("net_type", "UNKNOWN") + self._is_top_down = ( + self._engine == Engine.PYTORCH and pose_cfg.get("method").lower() == "td" + ) + self.pose_cfg = pose_cfg diff --git a/deeplabcut/gui/displays/shuffle_metadata_viewer.py b/deeplabcut/gui/displays/shuffle_metadata_viewer.py new file mode 100644 index 0000000000..b18aef85b4 --- /dev/null +++ b/deeplabcut/gui/displays/shuffle_metadata_viewer.py @@ -0,0 +1,63 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +"""Widget to display existing shuffles""" +from __future__ import annotations + +from PySide6 import QtWidgets +from PySide6.QtCore import Qt + +import deeplabcut.generate_training_dataset.metadata as metadata + + +class ShuffleMetadataViewer(QtWidgets.QDialog): + """Viewer for shuffle metadata""" + + def __init__(self, root: QtWidgets.QMainWindow, parent: QtWidgets.QWidget): + super().__init__(parent) + self.root = root + self.parent = parent + self.file_content = _load_metadata(self.root.cfg) + + self.setWindowTitle("Existing Shuffles: Metadata") + self.setMinimumWidth(400) + self.setMinimumHeight(400) + + scroll = QtWidgets.QScrollArea() + scroll.setWidgetResizable(True) + + inner_layout = QtWidgets.QVBoxLayout() + inner_layout.setAlignment(Qt.AlignLeft | Qt.AlignTop) + inner_layout.setSpacing(0) + inner_layout.setContentsMargins(0, 0, 0, 0) + + for line in self.file_content: + + inner_layout.addWidget(QtWidgets.QLabel(line)) + + inner = QtWidgets.QFrame(scroll) + inner.setLayout(inner_layout) + scroll.setWidget(inner) + + layout = QtWidgets.QVBoxLayout() + layout.addWidget(scroll) + self.setLayout(layout) + + +def _load_metadata(cfg: dict) -> list[str]: + metadata_path = metadata.TrainingDatasetMetadata.path(cfg) + if not metadata_path.exists(): + trainset_meta = metadata.TrainingDatasetMetadata.create(cfg) + trainset_meta.save() + + with open(metadata_path, "r") as file: + raw_metadata = file.read() + + return raw_metadata.split("\n") diff --git a/deeplabcut/gui/dlc_params.py b/deeplabcut/gui/dlc_params.py index 6563682951..fce5d15a52 100644 --- a/deeplabcut/gui/dlc_params.py +++ b/deeplabcut/gui/dlc_params.py @@ -31,8 +31,6 @@ class DLCParams: "efficientnet-b6", ] - IMAGE_AUGMENTERS = ["default", "tensorpack", "imgaug"] - FRAME_EXTRACTION_ALGORITHMS = ["kmeans", "uniform"] OUTLIER_EXTRACTION_ALGORITHMS = ["jump", "fitting", "uncertain", "manual"] diff --git a/deeplabcut/gui/media/dlc-pt.png b/deeplabcut/gui/media/dlc-pt.png new file mode 100644 index 0000000000..d0ac99c187 Binary files /dev/null and b/deeplabcut/gui/media/dlc-pt.png differ diff --git a/deeplabcut/gui/media/dlc-tf.png b/deeplabcut/gui/media/dlc-tf.png new file mode 100644 index 0000000000..79d06f0528 Binary files /dev/null and b/deeplabcut/gui/media/dlc-tf.png differ diff --git a/deeplabcut/gui/tabs/analyze_videos.py b/deeplabcut/gui/tabs/analyze_videos.py index e608839701..60971f4c3e 100644 --- a/deeplabcut/gui/tabs/analyze_videos.py +++ b/deeplabcut/gui/tabs/analyze_videos.py @@ -27,6 +27,7 @@ import deeplabcut from deeplabcut.utils.auxiliaryfunctions import edit_config +from deeplabcut.utils import auxfun_multianimal class AnalyzeVideos(DefaultTab): @@ -198,36 +199,36 @@ def _generate_layout_multianimal(self, layout): layout.addLayout(tmp_layout) def update_create_video_detections(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Create video with all detections {s}") def update_assemble_with_ID_only(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Assembly with ID only {s}") def update_calibrate_assembly(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Assembly calibration {s}") def update_tracker_type(self, method): self.root.logger.info(f"Using {method.upper()} tracker") def update_csv_choice(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Save results as CSV {s}") def update_filter_choice(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Filtering predictions {s}") def update_showfigs_choice(self, state): - if state == Qt.Checked: + if Qt.CheckState(state) == Qt.Checked: self.root.logger.info("Plots will show as pop ups.") else: self.root.logger.info("Plots will not show up.") def update_crop_choice(self, state): - if state == Qt.Checked: + if Qt.CheckState(state) == Qt.Checked: self.root.logger.info("Dynamic bodypart cropping ENABLED.") self.dynamic_cropping = True else: @@ -235,7 +236,8 @@ def update_crop_choice(self, state): self.dynamic_cropping = False def update_plot_trajectory_choice(self, state): - if state == Qt.Checked: + if Qt.CheckState(state) == Qt.Checked: + self.bodyparts_list_widget.refresh() self.bodyparts_list_widget.show() self.bodyparts_list_widget.setEnabled(True) self.show_trajectory_plots.setEnabled(True) @@ -263,12 +265,8 @@ def analyze_videos(self): videotype = self.video_selection_widget.videotype_widget.currentText() if self.root.is_multianimal: - calibrate_assembly = ( - self.calibrate_assembly_checkbox.isChecked() - ) - assemble_with_ID_only = ( - self.assemble_with_ID_only_checkbox.isChecked() - ) + calibrate_assembly = self.calibrate_assembly_checkbox.isChecked() + assemble_with_ID_only = self.assemble_with_ID_only_checkbox.isChecked() track_method = self.tracker_type_widget.currentText() edit_config(self.root.config, {"default_track_method": track_method}) num_animals_in_videos = self.num_animals_in_videos.value() @@ -339,6 +337,7 @@ def run_enabled(self): shuffle=shuffle, ) + track_method = auxfun_multianimal.get_track_method(self.root.cfg) if filter_data: deeplabcut.filterpredictions( config, @@ -348,6 +347,7 @@ def run_enabled(self): filtertype="median", windowlength=5, save_as_csv=save_as_csv, + track_method=track_method, ) if self.plot_trajectories.isChecked(): @@ -355,7 +355,6 @@ def run_enabled(self): self.root.logger.debug( f"Selected body parts for plot_trajectories: {bdpts}" ) - showfig = self.show_trajectory_plots.isChecked() deeplabcut.plot_trajectories( config, videos=videos, @@ -363,7 +362,8 @@ def run_enabled(self): videotype=videotype, shuffle=shuffle, filtered=filter_data, - showfigures=showfig, + showfigures=self.show_trajectory_plots.isChecked(), + track_method=track_method, ) if self.root.is_multianimal and save_as_csv: diff --git a/deeplabcut/gui/tabs/create_project.py b/deeplabcut/gui/tabs/create_project.py index 6eaf66fa38..95331613b2 100644 --- a/deeplabcut/gui/tabs/create_project.py +++ b/deeplabcut/gui/tabs/create_project.py @@ -4,24 +4,187 @@ # https://github.com/DeepLabCut/DeepLabCut # # Please see AUTHORS for contributors. -# https://github.com/DeepLabCut/DeepLabCut/blob/master/AUTHORS +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS # # Licensed under GNU Lesser General Public License v3.0 # import os from datetime import datetime +from PySide6 import QtCore, QtWidgets +from PySide6.QtGui import QBrush, QColor, QDesktopServices, QIcon, QPainter, QPen + import deeplabcut -from deeplabcut.utils import auxiliaryfunctions from deeplabcut.gui import BASE_DIR from deeplabcut.gui.dlc_params import DLCParams from deeplabcut.gui.widgets import ClickableLabel, ItemSelectionFrame +from deeplabcut.gui.tabs.docs import ( + URL_3D, + URL_MA_CONFIGURE, + URL_USE_GUIDE_SCENARIO, +) +from deeplabcut.utils import auxiliaryfunctions -from PySide6 import QtCore, QtWidgets -from PySide6.QtGui import QIcon + +class DynamicTextList(QtWidgets.QWidget): + """Dynamically add text entries""" + + def __init__(self, label_text="bodyparts", parent=None): + super(DynamicTextList, self).__init__(parent) + self.label_text = label_text + self.layout = QtWidgets.QVBoxLayout(self) + self.layout.setContentsMargins(0, 0, 0, 0) + + # Set maximum width for the widget + self.setMaximumWidth(300) + + # Add explanatory label + label = QtWidgets.QLabel(label_text) + self.layout.addWidget(label) + + # Create scroll area and its widget + self.scroll = QtWidgets.QScrollArea() + self.scroll.setWidgetResizable(True) + self.scroll.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.scroll.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) + self.scroll.setFrameShape(QtWidgets.QFrame.NoFrame) # Remove frame border + + # Create widget to hold the entries + self.entries_widget = QtWidgets.QWidget() + self.entries_layout = QtWidgets.QVBoxLayout(self.entries_widget) + self.entries_layout.setContentsMargins(0, 0, 0, 0) + self.entries_layout.setSpacing(5) # Consistent spacing between entries + self.entries_layout.setAlignment(QtCore.Qt.AlignTop) # Align entries to top + + # Add stretch at the bottom to keep entries at top + self.entries_layout.addStretch() + + self.scroll.setWidget(self.entries_widget) + + # Set fixed height for 6 items + self.entry_height = 30 # Fixed height for each entry + self.padding = 10 # Extra padding + self.scroll.setFixedHeight(5 * self.entry_height + self.padding) + + # Add scroll area to main layout + self.layout.addWidget(self.scroll) + + self.entries = [] + self.add_entry() + + def add_entry(self): + # Create horizontal layout for index and entry + entry_layout = QtWidgets.QHBoxLayout() + entry_layout.setContentsMargins(0, 0, 10, 0) + entry_layout.setSpacing(5) # Consistent spacing between index and entry + + # Create container widget for the entry row + entry_widget = QtWidgets.QWidget() + entry_widget.setFixedHeight(self.entry_height) + entry_widget.setLayout(entry_layout) + + # Add index label + index_label = QtWidgets.QLabel(str(len(self.entries) + 1) + ".") + index_label.setFixedWidth(20) # Set fixed width for alignment + entry_layout.addWidget(index_label) + + # Add text entry + entry = QtWidgets.QLineEdit() + entry.setFixedHeight(self.entry_height - 6) # Slightly smaller than container + entry.textChanged.connect(self._on_text_changed) + entry.textEdited.connect(lambda text: self._check_for_spaces(entry, text)) + self.entries.append((entry, index_label)) # Store both widgets + entry_layout.addWidget(entry) + + # Insert the new entry before the stretch + self.entries_layout.insertWidget(len(self.entries) - 1, entry_widget) + + def _check_for_spaces(self, entry, text): + if " " in text: + msg = QtWidgets.QMessageBox() + msg.setIcon(QtWidgets.QMessageBox.Warning) + msg.setText( + f"Spaces are not allowed in the {self.label_text} list. Use underscores " + f"instead." + ) + msg.setWindowTitle("Warning") + msg.exec_() + entry.setText(entry.text().replace(" ", "_")) + + def _on_text_changed(self): + # If the last entry has text, add a new empty entry + if self.entries[-1][0].text(): + self.add_entry() + + # Remove any empty entries except the last one + entries_to_remove = [] + for i, (entry, _) in enumerate(self.entries[:-1]): + if not entry.text(): + entries_to_remove.append(i) + + for i in reversed(entries_to_remove): + entry_widget = self.entries[i][0].parent() + self.entries_layout.removeWidget(entry_widget) + entry_widget.deleteLater() + self.entries.pop(i) + + self._update_indices() # Update the indices after removal + + def get_entries(self): + return [entry[0].text() for entry in self.entries if entry[0].text()] + + def _update_indices(self): + for i, (entry, index_label) in enumerate(self.entries): + index_label.setText(str(i + 1) + ".") + + +class Switch(QtWidgets.QPushButton): + + def __init__(self, on_text="Yes", off_text="No", width=80, parent=None): + super().__init__(parent) + self.on_text = on_text + self.off_text = off_text + self.setCheckable(True) + self.setFixedWidth(width) + self.setMinimumHeight(22) + + def paintEvent(self, event): + # Colors: https://qdarkstylesheet.readthedocs.io/en/latest/color_reference.html + label = self.on_text if self.isChecked() else self.off_text + bg_color = "#00ff00" if self.isChecked() else "#9DA9B5" + + radius = 10 + width = 32 + center = self.rect().center() + + painter = QPainter(self) + painter.setRenderHint(QPainter.Antialiasing) + painter.translate(center) + painter.setBrush(QColor(69, 83, 100)) # Lighter gray background + + pen = QPen("#455364") + pen.setWidth(2) + painter.setPen(pen) + + painter.drawRoundedRect( + QtCore.QRect(-width, -radius, 2 * width, 2 * radius), radius, radius + ) + painter.setBrush(QBrush(bg_color)) + sw_rect = QtCore.QRect(-radius, -radius, width + radius, 2 * radius) + if not self.isChecked(): + sw_rect.moveLeft(-width) + + painter.drawRoundedRect(sw_rect, radius, radius) + + pen = QPen("#000000") + pen.setWidth(2) + painter.setPen(pen) + painter.drawText(sw_rect, QtCore.Qt.AlignCenter, label) class ProjectCreator(QtWidgets.QDialog): + """Project creation dialog""" + def __init__(self, parent): super(ProjectCreator, self).__init__(parent) self.parent = parent @@ -34,6 +197,19 @@ def __init__(self, parent): self.exp_default = "" self.loc_default = parent.project_folder + self.bodypart_list = None + self.individuals_list = None + self.unique_bodyparts_list = None + + self.toggle_3d = Switch() + self.toggle_3d.setChecked(False) + self.madlc_toggle = Switch() + self.madlc_toggle.setChecked(False) + self.unique_toggle = Switch() + self.unique_toggle.setChecked(False) + self.identity_toggle = Switch() + self.identity_toggle.setChecked(False) + main_layout = QtWidgets.QVBoxLayout(self) self.user_frame = self.lay_out_user_frame() self.video_frame = self.lay_out_video_frame() @@ -80,46 +256,172 @@ def lay_out_user_frame(self): grid.addWidget(self.loc_line, 2, 1) vbox.addLayout(grid) - self.madlc_box = QtWidgets.QCheckBox("Is it a multi-animal project?") - self.madlc_box.setChecked(False) - vbox.addWidget(self.madlc_box) + widget_3d = self.build_toggle_widget( + switch=self.toggle_3d, + question="Do you want to create a 3D pose estimation project?", + help_text="(What is needed for a 3D project?)", + docs_link=URL_3D, + ) + madlc_widget = self.build_toggle_widget( + switch=self.madlc_toggle, + question="Are there multiple individuals in your videos?", + help_text="(Why does this matter?)", + docs_link=URL_USE_GUIDE_SCENARIO, + ) + + # Only visible when the maDLC widget is checked + unique_widget = self.build_toggle_widget( + switch=self.unique_toggle, + question="Do you have unique bodyparts in your video?", + help_text="(What are unique bodyparts?)", + docs_link=URL_MA_CONFIGURE, + ) + unique_widget.setVisible(False) + + # Labelling with identity + identity_widget = self.build_toggle_widget( + switch=self.identity_toggle, + question="Label with identity?", + help_text="(What is labeling with identity?)", + docs_link=URL_MA_CONFIGURE, + ) + identity_widget.setVisible(False) + + vbox.addWidget(widget_3d, alignment=QtCore.Qt.AlignTop) + vbox.addWidget(madlc_widget, alignment=QtCore.Qt.AlignTop) + vbox.addWidget(unique_widget, alignment=QtCore.Qt.AlignTop) + vbox.addWidget(identity_widget, alignment=QtCore.Qt.AlignTop) + + # Create horizontal layout for the two lists + lists_layout = QtWidgets.QHBoxLayout() + lists_layout.setAlignment(QtCore.Qt.AlignTop) + + # Create both DynamicTextList widgets as class attributes + self.bodypart_list = DynamicTextList( + label_text="Bodyparts to track", + parent=self, + ) + + self.individuals_list = DynamicTextList( + label_text="Individual names", + parent=self, + ) + self.individuals_list.setVisible(False) + + self.unique_bodyparts_list = DynamicTextList( + label_text="Unique bodyparts to track", + parent=self, + ) + self.unique_bodyparts_list.setVisible(False) + + # Connect toggle state to individuals list visibility, unique, identity + self.madlc_toggle.toggled.connect(self.individuals_list.setVisible) + self.madlc_toggle.toggled.connect(unique_widget.setVisible) + self.madlc_toggle.toggled.connect(identity_widget.setVisible) + + # Connect the unique_toggle to the unique_bodyparts_list + self.unique_toggle.toggled.connect( + lambda yes: self.unique_bodyparts_list.setVisible( + yes and self.madlc_toggle.isChecked() + ) + ) + + # Connect 3d toggle to all other option visibility + self.toggle_3d.toggled.connect(lambda yes: madlc_widget.setVisible(not yes)) + self.toggle_3d.toggled.connect( + lambda checked_3d: unique_widget.setVisible( + not checked_3d and self.madlc_toggle.isChecked() + ) + ) + self.toggle_3d.toggled.connect( + lambda checked_3d: identity_widget.setVisible( + not checked_3d and self.madlc_toggle.isChecked() + ) + ) + self.toggle_3d.toggled.connect( + lambda checked_3d: self.bodypart_list.setVisible(not checked_3d) + ) + self.toggle_3d.toggled.connect( + lambda checked_3d: self.individuals_list.setVisible( + not checked_3d and self.madlc_toggle.isChecked() + ) + ) + self.toggle_3d.toggled.connect( + lambda checked_3d: self.unique_bodyparts_list.setVisible( + not checked_3d + and self.madlc_toggle.isChecked() + and self.unique_toggle.isChecked() + ) + ) + # Add both lists to the horizontal layout with top alignment + lists_layout.addWidget(self.bodypart_list, alignment=QtCore.Qt.AlignTop) + lists_layout.addWidget(self.individuals_list, alignment=QtCore.Qt.AlignTop) + lists_layout.addWidget(self.unique_bodyparts_list, alignment=QtCore.Qt.AlignTop) + + # Add the horizontal layout to the main vertical layout + vbox.addLayout(lists_layout) return user_frame + def build_toggle_widget( + self, + switch: Switch, + question: str, + help_text: str, + docs_link: str, + ) -> QtWidgets.QWidget: + toggle_layout = QtWidgets.QHBoxLayout() + toggle_layout.setContentsMargins(0, 0, 0, 0) + toggle_layout.setSpacing(10) + + toggle_label = QtWidgets.QLabel(question) + toggle_label.setAlignment(QtCore.Qt.AlignLeft) + help_label = ClickableLabel(help_text, parent=self) + help_label.setStyleSheet("text-decoration: underline; font-weight: bold;") + help_label.setCursor(QtCore.Qt.PointingHandCursor) + help_label.signal.connect( + lambda: QDesktopServices.openUrl(QtCore.QUrl(docs_link)) + ) + + toggle_layout.addWidget(switch, alignment=QtCore.Qt.AlignLeft) + toggle_layout.addWidget(toggle_label, alignment=QtCore.Qt.AlignLeft) + toggle_layout.addStretch() + toggle_layout.addWidget(help_label, alignment=QtCore.Qt.AlignRight) + toggle_widget = QtWidgets.QWidget() + toggle_widget.setLayout(toggle_layout) + return toggle_widget + def lay_out_video_frame(self): video_frame = ItemSelectionFrame([], self) - self.cam_combo = QtWidgets.QComboBox(video_frame) - self.cam_combo.addItems(map(str, (1, 2))) - self.cam_combo.currentTextChanged.connect(self.check_num_cameras) - ncam_label = QtWidgets.QLabel("Number of cameras:") - ncam_label.setBuddy(self.cam_combo) - self.copy_box = QtWidgets.QCheckBox("Copy videos to project folder") self.copy_box.setChecked(False) - browse_button = QtWidgets.QPushButton("Browse videos") + browse_button = QtWidgets.QPushButton("Browse folders for videos") browse_button.clicked.connect(self.browse_videos) clear_button = QtWidgets.QPushButton("Clear") clear_button.clicked.connect(video_frame.fancy_list.clear) - layout1 = QtWidgets.QHBoxLayout() - layout1.addWidget(ncam_label) - layout1.addWidget(self.cam_combo) - layout2 = QtWidgets.QHBoxLayout() - layout2.addWidget(browse_button) - layout2.addWidget(clear_button) - video_frame.layout.insertLayout(0, layout1) - video_frame.layout.addLayout(layout2) + layout = QtWidgets.QHBoxLayout() + layout.addWidget(browse_button) + layout.addWidget(clear_button) + video_frame.layout.addLayout(layout) video_frame.layout.addWidget(self.copy_box) + self.toggle_3d.toggled.connect(lambda yes: self.copy_box.setVisible(not yes)) + self.toggle_3d.toggled.connect(lambda yes: browse_button.setVisible(not yes)) + self.toggle_3d.toggled.connect(lambda yes: clear_button.setVisible(not yes)) + self.toggle_3d.toggled.connect(lambda yes: video_frame.setVisible(not yes)) return video_frame def browse_videos(self): + options = QtWidgets.QFileDialog.Options() + options |= QtWidgets.QFileDialog.DontUseNativeDialog folder = QtWidgets.QFileDialog.getExistingDirectory( self, "Please select a folder", self.loc_default, + options, ) if not folder: return @@ -128,7 +430,7 @@ def browse_videos(self): folder, relative=False, ): - if os.path.splitext(video)[1][1:] in DLCParams.VIDEOTYPES[1:]: + if os.path.splitext(video)[1][1:].lower() in DLCParams.VIDEOTYPES[1:]: self.video_frame.fancy_list.add_item(video) def finalize_project(self): @@ -142,13 +444,13 @@ def finalize_project(self): if empty: return - n_cameras = int(self.cam_combo.currentText()) + create_3d = self.toggle_3d.isChecked() try: - if n_cameras > 1: + if create_3d: _ = deeplabcut.create_new_project_3d( self.proj_default, self.exp_default, - n_cameras, + 2, self.loc_default, ) else: @@ -162,7 +464,7 @@ def finalize_project(self): self.video_frame.fancy_list._default_style ) to_copy = self.copy_box.isChecked() - is_madlc = self.madlc_box.isChecked() + is_madlc = self.madlc_toggle.isChecked() config = deeplabcut.create_new_project( self.proj_default, self.exp_default, @@ -171,16 +473,44 @@ def finalize_project(self): to_copy, multianimal=is_madlc, ) + + if self.bodypart_list is not None: + bodypart_key = "bodyparts" + updates = {} + if is_madlc: + bodypart_key = "multianimalbodyparts" + if self.individuals_list is not None: + individuals = self.individuals_list.get_entries() + if len(individuals) > 0: + updates["individuals"] = individuals + + if ( + self.unique_toggle.isChecked() + and self.unique_bodyparts_list is not None + ): + unique_bodyparts = self.unique_bodyparts_list.get_entries() + if len(unique_bodyparts) > 0: + updates["uniquebodyparts"] = unique_bodyparts + + if self.identity_toggle.isChecked(): + updates["identity"] = True + + bodyparts = self.bodypart_list.get_entries() + if len(bodyparts) > 0: + updates[bodypart_key] = bodyparts + + if len(updates) > 0: + cfg: dict = auxiliaryfunctions.read_config(config) + cfg.update(**updates) + auxiliaryfunctions.write_config(config, cfg) + self.parent.load_config(config) - self.parent._update_project_state( - config=config, - loaded=True, - ) + self.parent._update_project_state(config=config, loaded=True) except FileExistsError: print('Project "{}" already exists!'.format(self.proj_default)) return - msg = QtWidgets.QMessageBox(text=f"New project created") + msg = QtWidgets.QMessageBox(text="New project created") msg.setIcon(QtWidgets.QMessageBox.Information) msg.exec_() @@ -195,15 +525,6 @@ def on_click(self): self.loc_default = dirname self.update_project_location() - def check_num_cameras(self, value): - val = int(value) - for child in self.video_frame.children(): - if child.isWidgetType() and not isinstance(child, QtWidgets.QComboBox): - if val > 1: - child.setDisabled(True) - else: - child.setDisabled(False) - def update_project_name(self, text): self.proj_default = text self.update_project_location() diff --git a/deeplabcut/gui/tabs/create_training_dataset.py b/deeplabcut/gui/tabs/create_training_dataset.py index 7101b0e5b3..15c06418c3 100644 --- a/deeplabcut/gui/tabs/create_training_dataset.py +++ b/deeplabcut/gui/tabs/create_training_dataset.py @@ -8,12 +8,23 @@ # # Licensed under GNU Lesser General Public License v3.0 # +from __future__ import annotations + import os +from pathlib import Path +import dlclibrary from PySide6 import QtWidgets -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, Slot from PySide6.QtGui import QIcon +import deeplabcut +import deeplabcut.compat as compat +from deeplabcut.core.engine import Engine +from deeplabcut.core.weight_init import WeightInitialization +from deeplabcut.generate_training_dataset import get_existing_shuffle_indices +from deeplabcut.generate_training_dataset.metadata import get_shuffle_engine +from deeplabcut.gui.displays.shuffle_metadata_viewer import ShuffleMetadataViewer from deeplabcut.gui.dlc_params import DLCParams from deeplabcut.gui.components import ( DefaultTab, @@ -21,8 +32,8 @@ _create_grid_layout, _create_label_widget, ) - -import deeplabcut +from deeplabcut.gui.widgets import launch_napari +from deeplabcut.modelzoo import build_weight_init from deeplabcut.utils.auxiliaryfunctions import ( get_data_and_metadata_filenames, get_training_set_folder, @@ -40,16 +51,36 @@ def __init__(self, root, parent, h1_description): self._generate_layout_attributes(self.layout_attributes) self.main_layout.addLayout(self.layout_attributes) + self.mapping_button = QtWidgets.QPushButton("Edit Conversion Table") + self.mapping_button.clicked.connect(self.edit_conversion_table) + self.mapping_button.setVisible(False) + self.root.engine_change.connect(self.set_edit_table_visibility) + self.ok_button = QtWidgets.QPushButton("Create Training Dataset") self.ok_button.setMinimumWidth(150) self.ok_button.clicked.connect(self.create_training_dataset) + self.main_layout.addWidget(self.mapping_button, alignment=Qt.AlignRight) self.main_layout.addWidget(self.ok_button, alignment=Qt.AlignRight) + self.view_shuffles_button = QtWidgets.QPushButton("View Existing Shuffles") + self.view_shuffles_button.clicked.connect(self.view_shuffles) + self.main_layout.addWidget(self.view_shuffles_button, alignment=Qt.AlignLeft) + self.help_button = QtWidgets.QPushButton("Help") self.help_button.clicked.connect(self.show_help_dialog) self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft) + def set_edit_table_visibility(self) -> None: + has_conversion_tables = bool( + self.root.cfg.get("SuperAnimalConversionTables", {}) + ) + is_pytorch_engine = self.root.engine == Engine.PYTORCH + is_finetuning = self.weight_init_selector.with_decoder + self.mapping_button.setVisible( + has_conversion_tables & is_pytorch_engine & is_finetuning + ) + def show_help_dialog(self): dialog = QtWidgets.QDialog(self) layout = QtWidgets.QVBoxLayout() @@ -74,29 +105,77 @@ def _generate_layout_attributes(self, layout): shuffle_label = QtWidgets.QLabel("Shuffle") self.shuffle = ShuffleSpinBox(root=self.root, parent=self) + # Dataset choices + self.weight_init_label = QtWidgets.QLabel("Weight Initialization") + self.weight_init_selector = WeightInitializationSelector(self.root) + self.update_weight_init_methods(self.root.engine) + self.root.engine_change.connect(self.update_weight_init_methods) + # Augmentation method augmentation_label = QtWidgets.QLabel("Augmentation method") self.aug_choice = QtWidgets.QComboBox() - self.aug_choice.addItems(DLCParams.IMAGE_AUGMENTERS) - self.aug_choice.setCurrentText("imgaug") + self.update_aug_methods(self.root.engine) + self.root.engine_change.connect(self.update_aug_methods) self.aug_choice.currentTextChanged.connect(self.log_augmentation_choice) # Neural Network nnet_label = QtWidgets.QLabel("Network architecture") self.net_choice = QtWidgets.QComboBox() - nets = DLCParams.NNETS.copy() - if not self.root.is_multianimal: - nets.remove("dlcrnet_ms5") - self.net_choice.addItems(nets) - self.net_choice.setCurrentText("resnet_50") + self.net_choice.setMinimumWidth(200) + self.update_nets(self.root.engine) + self.root.engine_change.connect(self.update_nets) self.net_choice.currentTextChanged.connect(self.log_net_choice) + # Update Net types when selected weight init changes + self.weight_init_selector.weight_init_choice.currentTextChanged.connect( + lambda _: self.update_nets(None) + ) + self.weight_init_selector.weight_init_choice.currentTextChanged.connect( + lambda _: self.set_edit_table_visibility() + ) + + # Detector selection for top-down models + self.detector_label = QtWidgets.QLabel("Detector architecture") + self.detector_choice = QtWidgets.QComboBox() + self.detector_choice.setMinimumWidth(200) + self.update_detectors(engine=self.root.engine) + self.root.engine_change.connect( + lambda engine: self.update_detectors(engine=engine) + ) + self.net_choice.currentTextChanged.connect( + lambda new_net_choice: self.update_detectors(net_choice=new_net_choice) + ) + + # Overwrite selection + self.overwrite = QtWidgets.QCheckBox("Overwrite if exists") + self.overwrite.setChecked(False) + self.overwrite.setToolTip( + "When checked, creating a new shuffle with an index that already exists " + "will overwrite the existing index. Be careful with this option as you " + "might lose data." + ) + self.overwrite.stateChanged.connect( + lambda s: self.root.logger.info(f"Overwrite: {s}") + ) + + # Use same data split as another shuffle + self.data_split_selection = DataSplitSelector(self.root, self) + layout.addWidget(shuffle_label, 0, 0) layout.addWidget(self.shuffle, 0, 1) - layout.addWidget(nnet_label, 0, 2) - layout.addWidget(self.net_choice, 0, 3) - layout.addWidget(augmentation_label, 0, 4) - layout.addWidget(self.aug_choice, 0, 5) + layout.addWidget(self.weight_init_label, 0, 2) + layout.addWidget(self.weight_init_selector, 0, 3) + + layout.addWidget(nnet_label, 1, 0) + layout.addWidget(self.net_choice, 1, 1) + layout.addWidget(augmentation_label, 1, 2) + layout.addWidget(self.aug_choice, 1, 3) + + layout.addWidget(self.detector_label, 2, 0) + layout.addWidget(self.detector_choice, 2, 1) + + layout.addWidget(self.overwrite, 3, 0) + layout.addWidget(self.data_split_selection, 4, 0) def log_net_choice(self, net): self.root.logger.info(f"Network architecture set to {net.upper()}") @@ -104,34 +183,134 @@ def log_net_choice(self, net): def log_augmentation_choice(self, augmentation): self.root.logger.info(f"Image augmentation set to {augmentation.upper()}") + def edit_conversion_table(self): + # Test beforehand whether a conversion table exists + memory_replay_folder = Path(self.root.project_folder) / "memory_replay" + conversion_matrix_out_path = str(memory_replay_folder / "confusion_matrix.png") + files = [self.root.config] + if os.path.exists(conversion_matrix_out_path): + files.append(conversion_matrix_out_path) + _ = launch_napari(files) + def create_training_dataset(self): shuffle = self.shuffle.value() + cfg = self.root.cfg + existing_indices = get_existing_shuffle_indices( + cfg=cfg, train_fraction=cfg["TrainingFraction"][self.root.trainingset_index] + ) + + overwrite = self.overwrite.isChecked() + if shuffle in existing_indices: + if overwrite: + if not self._confirm_overwrite(shuffle, existing_indices): + return + else: + msg = _create_message_box( + f"The training dataset could not be created.", + ( + f"Shuffle {shuffle} already exists - you can create a new " + "training dataset with an unused shuffle index (existing " + f"shuffles are {existing_indices}) or you can overwrite the " + f"shuffle by ticking the 'Overwrite' checkbox" + ), + ) + msg.exec_() + self.root.writer.write("Training dataset creation failed.") + return if self.model_comparison: raise NotImplementedError # TODO: finish model_comparison - deeplabcut.create_training_model_comparison( - config_file, - num_shuffles=shuffle, - net_types=self.net_type, - augmenter_types=self.aug_type, - ) + # deeplabcut.create_training_model_comparison( + # config_file, + # num_shuffles=shuffle, + # net_types=self.net_type, + # augmenter_types=self.aug_type, + # ) else: - if self.root.is_multianimal: - deeplabcut.create_multianimaltraining_dataset( - self.root.config, - shuffle, - Shuffles=[self.shuffle.value()], - net_type=self.net_choice.currentText(), + try: + engine = self.root.engine + net_type = self.net_choice.currentText() + detector_type = None + if engine == Engine.TF: + import tensorflow + + # try importing TF so they can't create shuffles for it if they + # don't have it installed + elif engine == Engine.PYTORCH and "top_down" in net_type: + detector_type = self.detector_choice.currentText() + + try: + weight_init = ( + self.weight_init_selector.get_super_animal_weight_init( + net_type, + detector_type, + ) + ) + except ValueError as err: + print(f"The training dataset could not be created: {err}.") + return + + if self.data_split_selection.selected: + deeplabcut.create_training_dataset_from_existing_split( + self.root.config, + from_shuffle=self.data_split_selection.from_shuffle, + shuffles=[self.shuffle.value()], + net_type=net_type, + detector_type=detector_type, + userfeedback=not overwrite, + weight_init=weight_init, + engine=engine, + ) + + elif self.root.is_multianimal: + deeplabcut.create_multianimaltraining_dataset( + self.root.config, + shuffle, + Shuffles=[self.shuffle.value()], + net_type=net_type, + detector_type=detector_type, + userfeedback=not overwrite, + weight_init=weight_init, + engine=engine, + ) + else: + deeplabcut.create_training_dataset( + self.root.config, + shuffle, + Shuffles=[self.shuffle.value()], + net_type=net_type, + detector_type=detector_type, + augmenter_type=self.aug_choice.currentText(), + userfeedback=not overwrite, + weight_init=weight_init, + engine=engine, + ) + except ValueError as err: + msg = _create_message_box( + f"The training dataset could not be created.", + str(err), ) - else: - deeplabcut.create_training_dataset( - self.root.config, - shuffle, - Shuffles=[self.shuffle.value()], - net_type=self.net_choice.currentText(), - augmenter_type=self.aug_choice.currentText(), + msg.exec_() + return + except ModuleNotFoundError as err: + info_text = ( + f"Error `{err}`. If the error is `ModuleNotFoundError: No module " + "named 'tensorflow'`, this is because you tried creating a " + "TensorFlow shuffle, but TensorFlow is not installed in your " + "environment. To create TensorFlow shuffles (and use TensorFlow " + "models), install it with\n" + " Windows/Linux:\n" + " pip install 'deeplabcut[tf]'\n" + " Apple Silicon:\n" + " pip install 'deeplabcut[apple_mchips]'" ) + msg = _create_message_box( + f"The training dataset could not be created.", info_text + ) + msg.exec_() + return + # Check that training data files were indeed created. trainingsetfolder = get_training_set_folder(self.root.cfg) filenames = list( @@ -148,6 +327,7 @@ def create_training_dataset(self): os.path.exists(os.path.join(self.root.project_folder, file)) for file in filenames ): + self.root.shuffle_created.emit(self.shuffle.value()) msg = _create_message_box( "The training dataset is successfully created.", "Use the function 'train_network' to start training. Happy training!", @@ -162,6 +342,369 @@ def create_training_dataset(self): msg.exec_() self.root.writer.write("Training dataset creation failed.") + def _confirm_overwrite(self, shuffle: int, existing_indices: list[int]) -> bool: + """ + Asks the user to confirm that they want to overwrite a shuffle. + + Args: + shuffle: the shuffle the user wants to overwrite + existing_indices: the indices of existing shuffles + + Returns: + whether the user confirmed overwriting the shuffle + """ + try: + engine = get_shuffle_engine( + self.root.cfg, self.root.trainingset_index, shuffle + ) + engine_str = f" (with engine '{engine.aliases[0]}')" + except ValueError: + engine_str = "" + + conf = _create_confirmation_box( + title=f"Are you sure you want to overwrite shuffle {shuffle}?", + description=( + f"As shuffle {shuffle} already exists{engine_str}, " + f"the training-dataset files would be overwritten." + ), + ) + result = conf.exec() + if result != QtWidgets.QMessageBox.Yes: + msg = _create_message_box( + text="The training dataset was not be created.", + info_text=( + "You can create a shuffle with another index. Existing indices " + f"are {existing_indices}" + ), + ) + msg.exec_() + self.root.writer.write("Training dataset creation interrupted.") + return False + + return True + + @Slot(Engine) + def update_nets(self, engine: Engine | None) -> None: + if engine is None: + engine = self.root.engine + + default_net = None + if engine == Engine.TF: + nets = DLCParams.NNETS.copy() + if not self.root.is_multianimal: + nets.remove("dlcrnet_ms5") + else: + # FIXME: Circular imports make it impossible to import this at the top + from deeplabcut.pose_estimation_pytorch import available_models + + nets = available_models() + net_filter = self.get_net_filter() + default_net = self.get_default_net() + td_prefix = "top_down_" + if net_filter is not None: + nets = [ + n + for n in nets + if ( + n in net_filter + or ( + n.startswith(td_prefix) + and n[len(td_prefix) :] in net_filter + ) + ) + ] + + while self.net_choice.count() > 0: + self.net_choice.removeItem(0) + + self.net_choice.addItems(nets) + if default_net is None: + default_net = self.root.cfg.get("default_net_type", "resnet_50") + + if default_net in nets: + self.net_choice.setCurrentIndex(nets.index(default_net)) + + @Slot(Engine) + def update_detectors( + self, + engine: Engine | None = None, + net_choice: str | None = None, + ) -> None: + if engine is None: + engine = self.root.engine + + if engine == Engine.TF: + detectors = [] + else: + # FIXME: Circular imports make it impossible to import this at the top + from deeplabcut.pose_estimation_pytorch import available_detectors + + detectors = available_detectors() + det_filter = self.get_detector_filter() + if det_filter is not None: + detectors = [d for d in detectors if d in det_filter] + + while self.detector_choice.count() > 0: + self.detector_choice.removeItem(0) + + self.detector_choice.addItems(detectors) + default_detector = self.get_default_detector() + if default_detector in detectors: + self.detector_choice.setCurrentIndex(detectors.index(default_detector)) + elif "ssdlite" in detectors: + self.detector_choice.setCurrentIndex(detectors.index("ssdlite")) + + if net_choice is None: + net_choice = self.net_choice.currentText() + + if "top_down" in net_choice: + self.detector_label.show() + self.detector_choice.show() + else: + self.detector_label.hide() + self.detector_choice.hide() + + @Slot(Engine) + def update_aug_methods(self, engine: Engine) -> None: + methods = compat.get_available_aug_methods(engine) + while self.aug_choice.count() > 0: + self.aug_choice.removeItem(0) + + self.aug_choice.addItems(methods) + self.aug_choice.setCurrentText(methods[0]) + + @Slot(Engine) + def update_weight_init_methods(self, engine: Engine) -> None: + if engine != Engine.PYTORCH: + self.weight_init_label.hide() + self.weight_init_selector.hide() + return + + self.weight_init_label.show() + self.weight_init_selector.update_choices(list(_WEIGHT_INIT_OPTIONS.keys())) + self.weight_init_selector.show() + + def get_net_filter(self) -> list[str] | None: + """Returns: the net type that can be used based on weight initialization""" + if self.root.engine != Engine.PYTORCH: + return None + + if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS: + return None + + weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] + if "super_animal" in weight_init_cfg: + return dlclibrary.get_available_models(weight_init_cfg["super_animal"]) + + return None + + def get_detector_filter(self) -> list[str] | None: + """Returns: the detectors that can be used based on weight initialization""" + if self.root.engine != Engine.PYTORCH: + return None + + if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS: + return None + + weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] + if "super_animal" in weight_init_cfg: + return dlclibrary.get_available_detectors(weight_init_cfg["super_animal"]) + + return None + + def get_default_net(self) -> str | None: + """Returns: the net type that can be used based on weight initialization""" + if self.root.engine != Engine.PYTORCH: + return None + + if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS: + return None + + weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] + return weight_init_cfg.get("default_net") + + def get_default_detector(self) -> str | None: + """Returns: the detector type that can be used based on weight initialization""" + if self.root.engine != Engine.PYTORCH: + return None + + if self.weight_init_selector.weight_init not in _WEIGHT_INIT_OPTIONS: + return None + + weight_init_cfg = _WEIGHT_INIT_OPTIONS[self.weight_init_selector.weight_init] + return weight_init_cfg.get("default_detector") + + def view_shuffles(self) -> None: + viewer = ShuffleMetadataViewer(root=self.root, parent=self) + viewer.show() + + +class WeightInitializationSelector(QtWidgets.QWidget): + """Widget to select weight initialization""" + + def __init__(self, root): + super().__init__() + self.root = root + + self.weight_init_choice = QtWidgets.QComboBox() + + self.memory_replay_label = QtWidgets.QLabel("With memory replay") + self.memory_replay_box = QtWidgets.QCheckBox() + self.memory_replay_label.hide() + self.memory_replay_box.hide() + + memory_replay_layout = QtWidgets.QHBoxLayout() + memory_replay_layout.addWidget(self.memory_replay_label) + memory_replay_layout.addWidget(self.memory_replay_box) + + layout = QtWidgets.QHBoxLayout() + layout.addWidget(self.weight_init_choice) + layout.addLayout(memory_replay_layout) + self.setLayout(layout) + + self.weight_init_choice.currentTextChanged.connect(self._choice_changed) + + @property + def weight_init(self) -> str: + return self.weight_init_choice.currentText() + + @property + def with_decoder(self) -> bool: + weight_init_choice = self.weight_init_choice.currentText() + return "fine-tuning" in weight_init_choice.lower() + + @property + def memory_replay(self) -> bool: + return self.memory_replay_box.isChecked() + + def update_choices(self, choices: list[str]) -> None: + """Updates the WeightInitialization methods that can be selected""" + while self.weight_init_choice.count() > 0: + self.weight_init_choice.removeItem(0) + self.weight_init_choice.addItems(choices) + + def get_super_animal_weight_init( + self, + net_type: str, + detector_type: str, + ) -> WeightInitialization | None: + """ + Args: + net_type: The architecture of the pose model from which to fine-tune a + SuperAnimal model. + detector_type: The architecture of the detector from which to fine-tune a + SuperAnimal model. + + Raises: + ValueError if WeightInitialization should be defined but could not be + created (e.g. if there's no conversion table). + """ + if self.root.engine != Engine.PYTORCH: + return None + + weight_init_choice = self.weight_init_choice.currentText() + if "imagenet" in weight_init_choice.lower(): + return + + weight_init_data = _WEIGHT_INIT_OPTIONS[weight_init_choice] + super_animal = weight_init_data["super_animal"] + if net_type.startswith("top_down_"): + net_type = net_type[len("top_down_") :] + try: + weight_init = build_weight_init( + self.root.cfg, + super_animal=super_animal, + model_name=net_type, + detector_name=detector_type, + with_decoder=self.with_decoder, + memory_replay=self.memory_replay, + ) + except ValueError as err: + QtWidgets.QMessageBox.critical( + self, + "Error", + ( + f"No Conversion table specified for {super_animal} in the project " + "configuration file. Please create a conversion table using the GUI" + ", with ``deeplabcut.modelzoo.utils.create_conversion_table``, or " + "by adding it to your project's configuration file manually." + ), + ) + raise err + + return weight_init + + def _choice_changed(self, state: str) -> None: + if "fine-tuning" in str(state).lower(): + self.memory_replay_label.show() + self.memory_replay_box.show() + else: + self.memory_replay_label.hide() + self.memory_replay_box.hide() + + +class DataSplitSelector(QtWidgets.QWidget): + """Allows users to create training sets with the same train/test split as another""" + + def __init__(self, root: QtWidgets.QMainWindow, parent: QtWidgets.QWidget): + super().__init__() + self.root = root + self.parent = parent + + self.setToolTip( + "This allows you to create a shuffle where the data split is the same as " + "one of your existing shuffles (the images on which the model is " + "trained/tested are the same)." + ) + + layout = QtWidgets.QVBoxLayout() + layout.setSpacing(0) + layout.setContentsMargins(0, 0, 0, 0) + + box_layout = QtWidgets.QHBoxLayout() + box_layout.setSpacing(0) + box_layout.setContentsMargins(0, 0, 0, 0) + + selector_layout = QtWidgets.QHBoxLayout() + selector_layout.setSpacing(0) + selector_layout.setContentsMargins(0, 0, 0, 0) + + self.shuffle_label = QtWidgets.QLabel("From shuffle:") + self.shuffle_label.hide() + self.shuffle_selector = QtWidgets.QSpinBox() + self.shuffle_selector.setMaximum(10_000) + self.shuffle_selector.setValue(0) + self.shuffle_selector.hide() + + self.box = QtWidgets.QCheckBox(parent=self) + self.box.stateChanged.connect(self._checkbox_status_changed) + self.box_label = QtWidgets.QLabel("Use an existing data split") + + box_layout.addWidget(self.box) + box_layout.addWidget(self.box_label) + selector_layout.addWidget(self.shuffle_label) + selector_layout.addWidget(self.shuffle_selector) + layout.addLayout(box_layout) + layout.addLayout(selector_layout) + self.setLayout(layout) + + @property + def selected(self) -> bool: + return self.box.isChecked() + + @property + def from_shuffle(self) -> int: + """The shuffle from which to copy the data split""" + return self.shuffle_selector.value() + + def _checkbox_status_changed(self, state: int) -> None: + if Qt.CheckState(state) == Qt.Checked: + self.shuffle_selector.show() + self.shuffle_label.show() + else: + self.shuffle_selector.hide() + self.shuffle_label.hide() + def _create_message_box(text, info_text): msg = QtWidgets.QMessageBox() @@ -176,3 +719,56 @@ def _create_message_box(text, info_text): msg.setWindowIcon(QIcon(logo)) msg.setStandardButtons(QtWidgets.QMessageBox.Ok) return msg + + +def _create_confirmation_box(title, description): + msg = QtWidgets.QMessageBox() + msg.setIcon(QtWidgets.QMessageBox.Information) + msg.setText(title) + msg.setInformativeText(description) + + msg.setWindowTitle("Confirmation") + msg.setMinimumWidth(900) + logo_dir = os.path.dirname(os.path.realpath("logo.png")) + os.path.sep + logo = logo_dir + "/assets/logo.png" + msg.setWindowIcon(QIcon(logo)) + msg.setStandardButtons(QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No) + return msg + + +_WEIGHT_INIT_OPTIONS = { # FIXME - Generate dynamically + "Transfer Learning - ImageNet": { + "model_filter": None, + "detector_filter": None, + }, + "Transfer Learning - SuperAnimal Bird": { + "default_net": "top_down_resnet_50", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_bird", + }, + "Transfer Learning - SuperAnimal Quadruped": { + "default_net": "top_down_hrnet_w32", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_quadruped", + }, + "Transfer Learning - SuperAnimal TopViewMouse": { + "default_net": "top_down_hrnet_w32", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_topviewmouse", + }, + "Fine-tuning - SuperAnimal Bird": { + "default_net": "top_down_resnet_50", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_bird", + }, + "Fine-tuning - SuperAnimal Quadruped": { + "default_net": "top_down_hrnet_w32", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_quadruped", + }, + "Fine-tuning - SuperAnimal TopViewMouse": { + "default_net": "top_down_hrnet_w32", + "default_detector": "fasterrcnn_mobilenet_v3_large_fpn", + "super_animal": "superanimal_topviewmouse", + }, +} diff --git a/deeplabcut/gui/tabs/create_videos.py b/deeplabcut/gui/tabs/create_videos.py index d8299b9445..9cdf92789e 100644 --- a/deeplabcut/gui/tabs/create_videos.py +++ b/deeplabcut/gui/tabs/create_videos.py @@ -134,7 +134,7 @@ def _generate_layout_video_parameters(self, layout): # Skeleton self.draw_skeleton_checkbox = QtWidgets.QCheckBox("Draw skeleton") - self.draw_skeleton_checkbox.setCheckState(Qt.Checked) + self.draw_skeleton_checkbox.setCheckState(Qt.Unchecked) self.draw_skeleton_checkbox.stateChanged.connect(self.update_draw_skeleton) tmp_layout.addWidget(self.draw_skeleton_checkbox) @@ -146,6 +146,24 @@ def _generate_layout_video_parameters(self, layout): ) tmp_layout.addWidget(self.use_filtered_data_checkbox) + # Selector for p-cutoff + pcutoff_widget = QtWidgets.QWidget() + pcutoff_layout = _create_horizontal_layout(margins=(0, 0, 0, 0)) + pcutoff_label = QtWidgets.QLabel("Plotting confidence cutoff (pcutoff)") + self.pcutoff_selector = QtWidgets.QDoubleSpinBox() + self.pcutoff_selector.setMinimum(0.0) + self.pcutoff_selector.setMaximum(1.0) + self.pcutoff_selector.setValue(0.6) + self.pcutoff_selector.setSingleStep(0.05) + pcutoff_layout.addWidget(pcutoff_label) + pcutoff_layout.addWidget(self.pcutoff_selector) + pcutoff_widget.setLayout(pcutoff_layout) + pcutoff_widget.setToolTip( + "This value sets the confidence threshold, above which predictions are " + "shown in the labeled videos." + ) + tmp_layout.addWidget(pcutoff_widget) + # Plot trajectories self.plot_trajectories = QtWidgets.QCheckBox("Plot trajectories") self.plot_trajectories.setCheckState(Qt.Unchecked) @@ -179,11 +197,11 @@ def _generate_layout_video_parameters(self, layout): layout.addLayout(tmp_layout, Qt.AlignLeft) def update_high_quality_video(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"High quality {s}.") def update_plot_trajectory_choice(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Plot trajectories {s}.") def update_selected_bodyparts(self): @@ -196,7 +214,7 @@ def update_selected_bodyparts(self): self.bodyparts_to_use = selected_bodyparts def update_use_all_bodyparts(self, s): - if s == Qt.Checked: + if Qt.CheckState(s) == Qt.Checked: self.bodyparts_list_widget.setEnabled(False) self.bodyparts_list_widget.hide() self.root.logger.info("Plot all bodyparts ENABLED.") @@ -207,15 +225,15 @@ def update_use_all_bodyparts(self, s): self.root.logger.info("Plot all bodyparts DISABLED.") def update_use_filtered_data(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Use filtered data {s}") def update_draw_skeleton(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Draw skeleton {s}") def update_overwrite_videos(self, state): - s = "ENABLED" if state == Qt.Checked else "DISABLED" + s = "ENABLED" if Qt.CheckState(state) == Qt.Checked else "DISABLED" self.root.logger.info(f"Overwrite videos {s}") def update_color_by(self, text): @@ -259,10 +277,12 @@ def create_videos(self): shuffle=shuffle, filtered=filtered, save_frames=self.create_high_quality_video.isChecked(), + pcutoff=self.pcutoff_selector.value(), displayedbodyparts=bodyparts, draw_skeleton=self.draw_skeleton_checkbox.isChecked(), trailpoints=trailpoints, color_by=color_by, + overwrite=self.overwrite_videos.isChecked(), ) if all(videos_created): self.root.writer.write("Labeled videos created.") @@ -280,6 +300,7 @@ def create_videos(self): shuffle=shuffle, filtered=filtered, displayedbodyparts=bodyparts, + pcutoff=self.pcutoff_selector.value(), ) def build_skeleton(self, *args): diff --git a/deeplabcut/gui/tabs/docs.py b/deeplabcut/gui/tabs/docs.py new file mode 100644 index 0000000000..1f52118e19 --- /dev/null +++ b/deeplabcut/gui/tabs/docs.py @@ -0,0 +1,15 @@ +# +# DeepLabCut Toolbox (deeplabcut.org) +# © A. & M.W. Mathis Labs +# https://github.com/DeepLabCut/DeepLabCut +# +# Please see AUTHORS for contributors. +# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS +# +# Licensed under GNU Lesser General Public License v3.0 +# +BASE_URL = "https://deeplabcut.github.io/DeepLabCut/docs/" +README = "https://deeplabcut.github.io/DeepLabCut/README.html" +URL_3D = BASE_URL + "Overviewof3D.html" +URL_MA_CONFIGURE = BASE_URL + "maDLC_UserGuide.html#configure-the-project" +URL_USE_GUIDE_SCENARIO = BASE_URL + "UseOverviewGuide.html#what-scenario-do-you-have" diff --git a/deeplabcut/gui/tabs/evaluate_network.py b/deeplabcut/gui/tabs/evaluate_network.py index 36389224db..9b46ee09ce 100644 --- a/deeplabcut/gui/tabs/evaluate_network.py +++ b/deeplabcut/gui/tabs/evaluate_network.py @@ -8,17 +8,21 @@ # # Licensed under GNU Lesser General Public License v3.0 # +from __future__ import annotations + import os import matplotlib.image as mpimg from matplotlib.backends.backend_qt5agg import ( FigureCanvasQTAgg as FigureCanvas, ) from matplotlib.figure import Figure +from pathlib import Path from PySide6 import QtWidgets -from PySide6.QtCore import Qt +from PySide6.QtCore import Qt, Slot import deeplabcut -from deeplabcut.utils.auxiliaryfunctions import get_evaluation_folder +from deeplabcut.core.engine import Engine +from deeplabcut.gui.displays.selected_shuffle_display import SelectedShuffleDisplay from deeplabcut.gui.components import ( BodypartListWidget, DefaultTab, @@ -27,7 +31,8 @@ _create_label_widget, _create_vertical_layout, ) -from deeplabcut.gui.widgets import ConfigEditor +from deeplabcut.gui.widgets import ConfigEditor, launch_napari +from deeplabcut.utils import auxiliaryfunctions class GridCanvas(QtWidgets.QDialog): @@ -91,6 +96,9 @@ def _set_page(self): self.help_button.clicked.connect(self.show_help_dialog) self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft) + self.root.engine_change.connect(self._on_engine_change) + self._on_engine_change(self.root.engine) + def show_help_dialog(self): dialog = QtWidgets.QDialog(self) layout = QtWidgets.QVBoxLayout() @@ -107,9 +115,11 @@ def show_help_dialog(self): def _generate_layout_attributes(self, layout): opt_text = QtWidgets.QLabel("Shuffle") self.shuffle = ShuffleSpinBox(root=self.root, parent=self) + self.shuffle_display = SelectedShuffleDisplay(self.root, row_margin=0) layout.addWidget(opt_text) layout.addWidget(self.shuffle) + layout.addWidget(self.shuffle_display) def open_inferencecfg_editor(self): editor = ConfigEditor(self.root.inference_cfg_path) @@ -124,7 +134,7 @@ def plot_maps(self): dest_folder = os.path.join( self.root.project_folder, str( - get_evaluation_folder( + auxiliaryfunctions.get_evaluation_folder( self.root.cfg["TrainingFraction"][0], shuffle, self.root.cfg ) ), @@ -159,19 +169,19 @@ def _generate_additional_attributes(self, layout): layout.addWidget(self.bodyparts_list_widget, alignment=Qt.AlignLeft) def update_map_choice(self, state): - if state == Qt.Checked: + if Qt.CheckState(state) == Qt.Checked: self.root.logger.info("Plot scoremaps ENABLED") else: self.root.logger.info("Plot predictions DISABLED") def update_plot_predictions(self, s): - if s == Qt.Checked: + if Qt.CheckState(s) == Qt.Checked: self.root.logger.info("Plot predictions ENABLED") else: self.root.logger.info("Plot predictions DISABLED") def update_bodypart_choice(self, s): - if s == Qt.Checked: + if Qt.CheckState(s) == Qt.Checked: self.bodyparts_list_widget.setEnabled(False) self.bodyparts_list_widget.hide() self.root.logger.info("Use all bodyparts") @@ -184,8 +194,7 @@ def update_bodypart_choice(self, s): def evaluate_network(self): config = self.root.config - - Shuffles = [self.root.shuffle_value] + shuffle = self.root.shuffle_value plotting = self.plot_predictions.isChecked() bodyparts_to_use = "all" @@ -197,8 +206,38 @@ def evaluate_network(self): deeplabcut.evaluate_network( config, - Shuffles=Shuffles, + Shuffles=[shuffle], plotting=plotting, show_errors=True, comparisonbodyparts=bodyparts_to_use, ) + + if plotting: + project_cfg = self.root.cfg + eval_folder = auxiliaryfunctions.get_evaluation_folder( + trainFraction=project_cfg["TrainingFraction"][0], + shuffle=shuffle, + cfg=project_cfg, + ) + scorer, _ = auxiliaryfunctions.get_scorer_name( + cfg=project_cfg, + shuffle=shuffle, + trainFraction=project_cfg["TrainingFraction"][0], + ) + + image_dir = ( + Path(self.root.project_folder) + / eval_folder + / f"LabeledImages_{scorer}" + ) + labeled_images = [str(p) for p in image_dir.rglob("*.png")] + if len(labeled_images) > 0: + _ = launch_napari(image_dir) + + @Slot(Engine) + def _on_engine_change(self, engine: Engine) -> None: + if engine == Engine.PYTORCH: + self.opt_button.hide() + return + + self.opt_button.show() diff --git a/deeplabcut/gui/tabs/label_frames.py b/deeplabcut/gui/tabs/label_frames.py index 52261ad3b1..9eb49b276c 100644 --- a/deeplabcut/gui/tabs/label_frames.py +++ b/deeplabcut/gui/tabs/label_frames.py @@ -19,6 +19,7 @@ from deeplabcut.generate_training_dataset import check_labels from deeplabcut.gui.components import DefaultTab from deeplabcut.gui.widgets import launch_napari +from deeplabcut.utils.skeleton import SkeletonBuilder def label_frames( @@ -110,8 +111,11 @@ def _set_page(self): self.label_frames_btn.clicked.connect(self.label_frames) self.check_labels_btn = QtWidgets.QPushButton("Check Labels") self.check_labels_btn.clicked.connect(self.check_labels) + self.build_skeleton_btn = QtWidgets.QPushButton("Build skeleton") + self.build_skeleton_btn.clicked.connect(self.build_skeleton) self.main_layout.addWidget(self.label_frames_btn, alignment=Qt.AlignLeft) self.main_layout.addWidget(self.check_labels_btn, alignment=Qt.AlignLeft) + self.main_layout.addWidget(self.build_skeleton_btn, alignment=Qt.AlignLeft) def log_color_by_option(self, choice): self.root.logger.info(f"Labeled images will by colored by {choice.upper()}") @@ -136,3 +140,8 @@ def label_frames(self): def check_labels(self): check_labels(self.root.config, visualizeindividuals=self.root.is_multianimal) + labeled_images = (Path(self.root.config).parent / "labeled-data").rglob("*_labeled/*.png") + _ = launch_napari(labeled_images, plugin="napari", stack=True) + + def build_skeleton(self, *args): + SkeletonBuilder(self.root.config) \ No newline at end of file diff --git a/deeplabcut/gui/tabs/modelzoo.py b/deeplabcut/gui/tabs/modelzoo.py index 97ddc2fe07..13039517f2 100644 --- a/deeplabcut/gui/tabs/modelzoo.py +++ b/deeplabcut/gui/tabs/modelzoo.py @@ -9,21 +9,25 @@ # Licensed under GNU Lesser General Public License v3.0 # import os +import webbrowser from functools import partial -import deeplabcut +import dlclibrary from PySide6 import QtWidgets -from PySide6.QtCore import Qt, Signal, QTimer, QRegularExpression -from PySide6.QtGui import QPixmap, QRegularExpressionValidator +from PySide6.QtCore import QRegularExpression, Qt, QTimer, Signal, Slot +from PySide6.QtGui import QIcon, QPixmap, QRegularExpressionValidator + +import deeplabcut +from deeplabcut.core.engine import Engine +from deeplabcut.gui import BASE_DIR from deeplabcut.gui.components import ( + _create_grid_layout, + _create_label_widget, DefaultTab, VideoSelectionWidget, - _create_label_widget, - _create_grid_layout, ) -from deeplabcut.gui import BASE_DIR from deeplabcut.gui.utils import move_to_separate_thread -from deeplabcut.modelzoo.utils import parse_available_supermodels +from deeplabcut.gui.widgets import ClickableLabel class RegExpValidator(QRegularExpressionValidator): @@ -40,6 +44,11 @@ def __init__(self, root, parent, h1_description): super().__init__(root, parent, h1_description) self._val_pattern = QRegularExpression(r"(\d{3,5},\s*)+\d{3,5}") self._set_page() + self.root.engine_change.connect(self._on_engine_change) + self.root.engine_change.connect(self._update_available_models) + self._update_pose_models(self.model_combo.currentText()) + self._update_detectors(self.model_combo.currentText()) + self._destfolder = None @property def files(self): @@ -50,18 +59,85 @@ def _set_page(self): self.video_selection_widget = VideoSelectionWidget(self.root, self) self.main_layout.addWidget(self.video_selection_widget) - model_settings_layout = _create_grid_layout(margins=(20, 0, 0, 0)) + self._build_common_attributes() + self._build_tf_attributes() + self._build_torch_attributes() + + self.run_button = QtWidgets.QPushButton("Run") + self.run_button.clicked.connect(self.run_video_adaptation) + self.main_layout.addWidget(self.run_button, alignment=Qt.AlignRight) + + self.home_button = QtWidgets.QPushButton("Return to Welcome page") + self.home_button.clicked.connect(self.root._generate_welcome_page) + self.main_layout.addWidget(self.home_button, alignment=Qt.AlignLeft) + self.help_button = QtWidgets.QPushButton("Help") + self.help_button.clicked.connect(self.show_help_dialog) + self.main_layout.addWidget(self.help_button, alignment=Qt.AlignLeft) + self.go_to_button = QtWidgets.QPushButton("Read Documentation") + # go to url https://deeplabcut.github.io/DeepLabCut/docs/ModelZoo.html#about-the-superanimal-models when button is clicked + self.go_to_button.clicked.connect( + lambda: webbrowser.open( + "https://deeplabcut.github.io/DeepLabCut/docs/ModelZoo.html#about-the-superanimal-models" + ) + ) + self.main_layout.addWidget(self.go_to_button, alignment=Qt.AlignLeft) + self._on_engine_change(self.root.engine) + + def _build_common_attributes(self) -> None: + settings_layout = _create_grid_layout(margins=(20, 0, 0, 0)) section_title = _create_label_widget( "Supermodel Settings", "font:bold", (0, 50, 0, 0) ) model_combo_text = QtWidgets.QLabel("Supermodel name") + model_combo_text.setMinimumWidth(300) self.model_combo = QtWidgets.QComboBox() - supermodels = parse_available_supermodels() - self.model_combo.addItems(supermodels.keys()) + self.model_combo.setMinimumWidth(250) + + net_type_text = QtWidgets.QLabel("Net Type") + net_type_text.setMinimumWidth(300) + self.net_type_selector = QtWidgets.QComboBox() + + self.detector_type_text = QtWidgets.QLabel("Detector Type") + self.detector_type_text.setMinimumWidth(300) + self.detector_type_selector = QtWidgets.QComboBox() + + loc_label = ClickableLabel("Folder to store results:", parent=self) + loc_label.signal.connect(self.select_folder) + self.loc_line = QtWidgets.QLineEdit( + "