3 changes: 3 additions & 0 deletions _toc.yml
@@ -27,6 +27,9 @@ parts:
chapters:
- file: docs/recipes/UsingModelZooPupil
- file: docs/recipes/MegaDetectorDLCLive
- caption: DeepLabCut Benchmark
chapters:
- file: docs/benchmark
- caption: Hardware
chapters:
- file: docs/recipes/TechHardware
116 changes: 116 additions & 0 deletions deeplabcut/benchmark/__init__.py
@@ -0,0 +1,116 @@
# DeepLabCut2.0 Toolbox (deeplabcut.org)
# © A. & M. Mathis Labs
# https://github.com/AlexEMG/DeepLabCut
# Please see AUTHORS for contributors.
#
# https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
# Licensed under GNU Lesser General Public License v3.0


import json
import os
from typing import Container, Literal, Optional

from deeplabcut.benchmark.base import Benchmark, Result, ResultCollection

DATA_ROOT = os.path.join(os.getcwd(), "data")
CACHE = os.path.join(os.getcwd(), ".results")

__registry = []


def register(cls):
"""Add a benchmark to the list of evaluations to run.

Apply this function as a decorator to a class. Note that the
class needs to be a subclass of the ``benchmark.base.Benchmark``
base class.

In most situations, it will be a subclass of one of the pre-defined
benchmarks in ``benchmark.benchmarks``.

    Raises:
``ValueError`` if the decorator is applied to a class that is
not a subclass of ``benchmark.base.Benchmark``.
"""
if not issubclass(cls, Benchmark):
raise ValueError(
f"Can only register subclasses of {type(Benchmark)}, " f"but got {cls}."
)
    __registry.append(cls)
    return cls
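As a usage sketch (hedged: ``TriMouseBenchmark`` stands in for one of the pre-defined classes in ``benchmark.benchmarks``, and all other names are illustrative):

import deeplabcut.benchmark as benchmark
from deeplabcut.benchmark.benchmarks import TriMouseBenchmark  # assumed pre-defined benchmark class

@benchmark.register
class MyMethodSubmission(TriMouseBenchmark):
    def names(self):
        # unique method name(s) that will appear in the benchmark table
        return ("my-method",)

    def get_predictions(self, name):
        # load and return the predictions of `name` for all benchmark images
        ...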


def evaluate(
    include_benchmarks: Optional[Container[str]] = None,
    results: Optional[ResultCollection] = None,
on_error="return",
) -> ResultCollection:
"""Run evaluation for all benchmarks and methods.

Note that in order for your custom benchmark to be included during
evaluation, the following conditions need to be met:

    - The benchmark subclasses one of the benchmark definitions
      in ``benchmark.benchmarks``
- The benchmark is registered by applying the ``@benchmark.register``
decorator to the class
- The benchmark was imported. This is done automatically for all
benchmarks that are defined in submodules or subpackages of the
``benchmark.submissions`` module. For all other locations, make
sure to manually import the packages **before** calling the
``evaluate()`` function.

Args:
include_benchmarks:
If ``None``, run all benchmarks that were discovered. If a container
is passed, only include methods that were defined on benchmarks with
the specified names. E.g., ``include_benchmarks = ["trimouse"]`` would
only evaluate methods of the trimouse benchmark dataset.
on_error:
see documentation in ``benchmark.base.Benchmark.evaluate()``

Returns:
A collection of all results, which can be printed or exported to
``pd.DataFrame`` or ``json`` file formats.
"""
if results is None:
results = ResultCollection()
for benchmark_cls in __registry:
if include_benchmarks is not None:
if benchmark_cls.name not in include_benchmarks:
continue
benchmark = benchmark_cls()
        for name in benchmark.names():
            if Result(method_name=name, benchmark_name=benchmark_cls.name) in results:
                continue
            result = benchmark.evaluate(name, on_error=on_error)
            # evaluate() returns None when on_error == "ignore"; skip those cases
            if result is not None:
                results.add(result)
return results
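A small usage sketch, assuming the submission classes (e.g. those under ``benchmark.submissions``) have already been imported:

import deeplabcut.benchmark as benchmark

# only evaluate methods registered for the "trimouse" benchmark;
# failed evaluations are returned with NaN metrics instead of raising
results = benchmark.evaluate(include_benchmarks=["trimouse"], on_error="return")
print(results.toframe())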


def get_filepath(basename: str):
return os.path.join(DATA_ROOT, basename)


def savecache(results: ResultCollection):
with open(CACHE, "w") as fh:
json.dump(results.todicts(), fh, indent=2)


def loadcache(
cache=CACHE, on_missing: Literal["raise", "ignore"] = "ignore"
) -> ResultCollection:
if not os.path.exists(cache):
if on_missing == "raise":
raise FileNotFoundError(cache)
return ResultCollection()
with open(cache, "r") as fh:
try:
data = json.load(fh)
except json.decoder.JSONDecodeError as e:
if on_missing == "raise":
raise e
return ResultCollection()
return ResultCollection.fromdicts(data)
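A quick sketch of the cache helpers defined above (the cache file location follows the module-level ``CACHE`` constant):

results = evaluate(on_error="return")
savecache(results)      # serializes the collection to the ".results" JSON file
cached = loadcache()    # returns an empty ResultCollection if no cache exists
print(len(cached), "cached results")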
12 changes: 12 additions & 0 deletions deeplabcut/benchmark/__main__.py
@@ -0,0 +1,12 @@
# DeepLabCut2.0 Toolbox (deeplabcut.org)
# © A. & M. Mathis Labs
# https://github.com/AlexEMG/DeepLabCut
# Please see AUTHORS for contributors.
#
# https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
# Licensed under GNU Lesser General Public License v3.0

from deeplabcut.benchmark.cli import main

if __name__ == "__main__":
main()
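Because the package ships this ``__main__`` module, the benchmark command line interface can be invoked with ``python -m deeplabcut.benchmark``; the accepted arguments are defined in ``deeplabcut.benchmark.cli``, which is not part of this excerpt.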
211 changes: 211 additions & 0 deletions deeplabcut/benchmark/base.py
@@ -0,0 +1,211 @@
# DeepLabCut2.0 Toolbox (deeplabcut.org)
# © A. & M. Mathis Labs
# https://github.com/AlexEMG/DeepLabCut
# Please see AUTHORS for contributors.
#
# https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
# Licensed under GNU Lesser General Public License v3.0

"""Base classes for benchmark and result definition

Benchmarks subclass the abstract ``Benchmark`` class and are defined by a ``name``, their
``keypoints`` names, and the ground truth and metadata needed to run the evaluation.
Right now, the metrics computed and reported for each of the multi-animal benchmarks
are the root mean squared error (RMSE) and the mean average precision (mAP).

Note for contributors: If you decide to contribute a benchmark which does not fit
into this evaluation framework, please feel free to extend the base classes
(e.g. to support additional metrics).
"""

import abc
import dataclasses
from typing import Iterable
from typing import Tuple

import pandas as pd

import deeplabcut.benchmark.metrics
from deeplabcut import __version__


class BenchmarkEvaluationError(RuntimeError):
pass


class Benchmark(abc.ABC):
"""Abstract benchmark baseclass.

All benchmarks should subclass this class.
"""

@abc.abstractmethod
def names(self):
"""A unique key to describe this submission, e.g. the model name.

This is also the name that will later appear in the benchmark table.
The name needs to be unique across the whole benchmark. Non-unique names
will raise an error during submission of a PR.
"""
raise NotImplementedError()

@abc.abstractmethod
    def get_predictions(self, name):
        """Return predictions for all images in the benchmark, for the given method ``name``."""
raise NotImplementedError()

def __init__(self):
keys = ["name", "keypoints", "ground_truth", "metadata"]
for key in keys:
if not hasattr(self, key):
raise NotImplementedError(
f"Subclass of abstract Benchmark class need "
f"to define the {key} property."
)

def compute_pose_rmse(self, results_objects):
return deeplabcut.benchmark.metrics.calc_rmse_from_obj(
results_objects, h5_file=self.ground_truth, metadata_file=self.metadata
)

def compute_pose_map(self, results_objects):
return deeplabcut.benchmark.metrics.calc_map_from_obj(
results_objects, h5_file=self.ground_truth, metadata_file=self.metadata
)

def evaluate(self, name: str, on_error="raise"):
"""Evaluate this benchmark with all registered methods."""

if name not in self.names():
raise ValueError(
f"{name} is not registered. Valid names are {self.names()}"
)
if on_error not in ("ignore", "return", "raise"):
raise ValueError(f"on_error got an undefined value: {on_error}")
mean_avg_precision = float("nan")
root_mean_squared_error = float("nan")
try:
predictions = self.get_predictions(name)
mean_avg_precision = self.compute_pose_map(predictions)
root_mean_squared_error = self.compute_pose_rmse(predictions)
except Exception as exception:
if on_error == "ignore":
# ignore the exception and continue with the next evaluation, without
# yielding a result value.
return
elif on_error == "return":
# return the result value, with NaN as the result for all metrics that
# could not be computed due to the error.
pass
elif on_error == "raise":
# raise the error and stop evaluation
raise BenchmarkEvaluationError(
f"Error during benchmark evaluation for model {name}"
) from exception
else:
raise NotImplementedError() from exception
return Result(
method_name=name,
benchmark_name=self.name,
mean_avg_precision=mean_avg_precision,
root_mean_squared_error=root_mean_squared_error,
)
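Continuing the ``ExampleBenchmark`` sketch from the module docstring, a single evaluation could look like this; with ``on_error="return"``, a failing prediction loader yields NaN metrics rather than an exception:

bench = ExampleBenchmark()
result = bench.evaluate("example-method", on_error="return")
print(result)   # "example-method, example: nan mAP, nan RMSE" if predictions could not be loaded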


@dataclasses.dataclass
class Result:
"""Benchmark result."""

method_name: str
benchmark_name: str
root_mean_squared_error: float = float("nan")
mean_avg_precision: float = float("nan")
benchmark_version: str = __version__

_export_mapping = dict(
benchmark_name="benchmark",
method_name="method",
benchmark_version="version",
root_mean_squared_error="RMSE",
mean_avg_precision="mAP",
)

_primary_key = ("benchmark_name", "method_name", "benchmark_version")

@property
    def primary_key(self) -> Tuple[str, ...]:
"""The primary key to uniquely identify this result."""
return tuple(getattr(self, k) for k in self._primary_key)

@property
    def primary_key_names(self) -> Tuple[str, ...]:
"""Names of the primary keys"""
return tuple(self._export_mapping.get(k) for k in self._primary_key)

def __str__(self):
return (
f"{self.method_name}, {self.benchmark_name}: "
f"{self.mean_avg_precision} mAP, "
f"{self.root_mean_squared_error} RMSE"
)

@classmethod
def fromdict(cls, data: dict):
"""Construct result object from dictionary."""
kwargs = {attr: data[key] for attr, key in cls._export_mapping.items()}
return cls(**kwargs)

def todict(self) -> dict:
"""Export result object to dictionary, with less verbose key names."""
return {key: getattr(self, attr) for attr, key in self._export_mapping.items()}
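For example, a result survives a dictionary round trip (the values below are made up):

res = Result(
    method_name="my-method",
    benchmark_name="trimouse",
    root_mean_squared_error=5.2,
    mean_avg_precision=0.91,
)
assert Result.fromdict(res.todict()) == res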


class ResultCollection:
def __init__(self, *results):
self.results = {result.primary_key: result for result in results}

@property
def primary_key_names(self):
return next(iter(self.results.values())).primary_key_names

def toframe(self) -> pd.DataFrame:
"""Convert results to pandas dataframe"""
return pd.DataFrame(
[result.todict() for result in self.results.values()]
).set_index(list(self.primary_key_names))

def add(self, result: Result):
"""Add a result to the collection."""
if result.primary_key in self.results:
            raise ValueError(
                f"An entry for {result.primary_key} already exists "
                "in this collection. Did you try to add the "
                "same result twice?"
            )
if len(self) > 0:
if result.primary_key_names != self.primary_key_names:
raise ValueError("Incompatible result format.")
self.results[result.primary_key] = result

@classmethod
def fromdicts(cls, data: Iterable[dict]):
return cls(*[Result.fromdict(entry) for entry in data])

def todicts(self):
return [result.todict() for result in self.results.values()]

def __len__(self):
return len(self.results)

def __contains__(self, other: Result):
if not isinstance(other, Result):
raise ValueError(
f"{type(self)} can only store objects of type Result, "
f"but got {type(other)}."
)
return other.primary_key in self.results

def __eq__(self, other):
if not isinstance(other, ResultCollection):
return False
return other.results == self.results
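A short sketch of how a collection can be assembled and exported (values are illustrative):

collection = ResultCollection(
    Result(method_name="method-a", benchmark_name="trimouse", mean_avg_precision=0.9),
    Result(method_name="method-b", benchmark_name="trimouse", mean_avg_precision=0.8),
)
collection.add(Result(method_name="method-c", benchmark_name="trimouse"))
print(collection.toframe())   # DataFrame indexed by (benchmark, method, version)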