Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions deeplabcut/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
import json
import os
from collections.abc import Container
from pathlib import Path
from typing import Literal

from deeplabcut.benchmark.base import Benchmark, Result, ResultCollection

DATA_ROOT = os.path.join(os.getcwd(), "data")
CACHE = os.path.join(os.getcwd(), ".results")
DATA_ROOT = Path.cwd() / "data"
CACHE = Path.cwd() / ".results"

__registry = []

Expand Down Expand Up @@ -99,20 +100,20 @@ def evaluate(


def get_filepath(basename: str):
return os.path.join(DATA_ROOT, basename)
return DATA_ROOT / basename


def savecache(results: ResultCollection):
with open(CACHE, "w") as fh:
with Path(CACHE).open("w") as fh:
json.dump(results.todicts(), fh, indent=2)


def loadcache(cache=CACHE, on_missing: Literal["raise", "ignore"] = "ignore") -> ResultCollection:
if not os.path.exists(cache):
if not Path(cache).exists():
if on_missing == "raise":
raise FileNotFoundError(cache)
return ResultCollection()
with open(cache) as fh:
with Path(cache).open() as fh:
try:
data = json.load(fh)
except json.decoder.JSONDecodeError as e:
Expand Down
8 changes: 4 additions & 4 deletions deeplabcut/benchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

"""Evaluation metrics for the DeepLabCut benchmark."""

import os
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
Expand All @@ -33,7 +33,7 @@ def _format_gt_data(h5file: str, test_indices: list[int] | None = None):
except KeyError:
n_unique = 0
guarantee_multiindex_rows(df)
file_paths = [os.path.join(*row) for row in df.index.to_list()]
file_paths = [Path(*row) for row in df.index.to_list()]
temp = (
df.stack("individuals", dropna=False)
.reindex(animals, level="individuals")
Expand Down Expand Up @@ -248,13 +248,13 @@ def load_test_images(h5file: str, metadata: str) -> list[str]:
test_images = []
for img_path in df_test.index:
if not isinstance(img_path, str):
img_path = os.path.join(*img_path)
img_path = str(Path(*img_path))
test_images.append(img_path)
return test_images


def _load_test_indices(shuffle_metadata_path: str) -> list[int]:
"""Returns the indices of test images in the training dataset dataframe."""
with open(shuffle_metadata_path, "rb") as f:
with Path(shuffle_metadata_path).open("rb") as f:
test_indices = set([int(i) for i in pickle.load(f)[2]])
return list(sorted(test_indices))
3 changes: 2 additions & 1 deletion deeplabcut/benchmark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import pkgutil
import sys
from pathlib import Path


class RedirectStdStreams:
Expand Down Expand Up @@ -46,7 +47,7 @@ def __exit__(self, exc_type, exc_value, traceback):

class DisableOutput(RedirectStdStreams):
def __init__(self):
devnull = open(os.devnull, "w")
devnull = Path(os.devnull).open("w")
super().__init__(stdout=devnull, stderr=devnull)


Expand Down
Loading