diff --git a/deeplabcut/gui/widgets.py b/deeplabcut/gui/widgets.py index a0bd2913b..855f69400 100644 --- a/deeplabcut/gui/widgets.py +++ b/deeplabcut/gui/widgets.py @@ -10,30 +10,20 @@ # import ast import os -import warnings from queue import Queue -import matplotlib.colors as mcolors import napari import numpy as np -import pandas as pd -from matplotlib.backends.backend_qt5agg import ( - FigureCanvasQTAgg as FigureCanvas, -) -from matplotlib.backends.backend_qt5agg import ( - NavigationToolbar2QT, -) -from matplotlib.collections import LineCollection +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT from matplotlib.figure import Figure -from matplotlib.path import Path from matplotlib.widgets import Button, LassoSelector, RectangleSelector from PySide6 import QtCore, QtWidgets from PySide6.QtGui import QAction, QCursor, QStandardItem, QStandardItemModel -from scipy.spatial import cKDTree as KDTree -from skimage import io from deeplabcut.utils import auxiliaryfunctions from deeplabcut.utils.auxfun_videos import VideoWriter +from deeplabcut.utils.skeleton import SkeletonBuilder as BaseSkeletonBuilder def launch_napari(files=None, plugin="napari-deeplabcut", stack=False): @@ -529,91 +519,34 @@ def display_help(self, *args): ) -class SkeletonBuilder(QtWidgets.QDialog): +class SkeletonBuilder(QtWidgets.QDialog, BaseSkeletonBuilder): def __init__(self, config_path, parent=None): - super().__init__(parent) - self.config_path = config_path - self.cfg = auxiliaryfunctions.read_config(config_path) - # Find uncropped labeled data - self.df = None - found = False - root = os.path.join(self.cfg["project_path"], "labeled-data") - for dir_ in os.listdir(root): - folder = os.path.join(root, dir_) - if os.path.isdir(folder) and not any(folder.endswith(s) for s in ("cropped", "labeled")): - self.df = pd.read_hdf(os.path.join(folder, f"CollectedData_{self.cfg['scorer']}.h5")) - row, col = self.pick_labeled_frame() - if "individuals" in self.df.columns.names: - self.df = self.df.xs(col, axis=1, level="individuals") - self.xy = self.df.loc[row].values.reshape((-1, 2)) - missing = np.flatnonzero(np.isnan(self.xy).all(axis=1)) - if not missing.size: - found = True - break - if self.df is None: - raise OSError("No labeled data were found.") - - self.bpts = self.df.columns.get_level_values("bodyparts").unique() - if not found: - warnings.warn( - f"A fully labeled animal could not be found. " - f"{', '.join(self.bpts[missing])} will need to be manually connected in the config.yaml.", - stacklevel=2, - ) - self.tree = KDTree(self.xy) - # Handle image previously annotated on a different platform - if isinstance(row, str): - sep = "/" if "/" in row else "\\" - row = row.split(sep) - self.image = io.imread(os.path.join(self.cfg["project_path"], *row)) - self.inds = set() - self.segs = set() - # Draw the skeleton if already existent - if self.cfg["skeleton"]: - for bone in self.cfg["skeleton"]: - pair = np.flatnonzero(self.bpts.isin(bone)) - if len(pair) != 2: - continue - pair_sorted = tuple(sorted(pair)) - self.inds.add(pair_sorted) - self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :]))) + QtWidgets.QDialog.__init__(self, parent) + self._parent = parent + self.setWindowTitle("Skeleton Builder") + BaseSkeletonBuilder.__init__(self, config_path) + def build_ui(self): self.fig = Figure() self.ax = self.fig.add_subplot(111) self.ax.axis("off") + ax_clear = self.fig.add_axes([0.85, 0.55, 0.1, 0.1]) ax_export = self.fig.add_axes([0.85, 0.45, 0.1, 0.1]) + self.clear_button = Button(ax_clear, "Clear") self.clear_button.on_clicked(self.clear) + self.export_button = Button(ax_export, "Export") self.export_button.on_clicked(self.export) + self.fig.canvas.mpl_connect("pick_event", self.on_pick) + self.canvas = FigureCanvas(self.fig) layout = QtWidgets.QVBoxLayout(self) layout.addWidget(self.canvas) self.setLayout(layout) - self.lines = LineCollection(self.segs, colors=mcolors.to_rgba(self.cfg["skeleton_color"])) - self.lines.set_picker(True) - self._show() - - def pick_labeled_frame(self): - # Find the most 'complete' animal - try: - count = self.df.groupby(level="individuals", axis=1).count() - if "single" in count: - count.drop("single", axis=1, inplace=True) - except KeyError: - count = self.df.count(axis=1).to_frame() - mask = count.where(count == count.values.max()) - kept = mask.stack().index.to_list() - np.random.shuffle(kept) - picked = kept.pop() - row = picked[:-1] - col = picked[-1] - return row, col - - def _show(self): lo = np.nanmin(self.xy, axis=0) hi = np.nanmax(self.xy, axis=0) center = (hi + lo) / 2 @@ -621,6 +554,7 @@ def _show(self): ampl = 1.3 w *= ampl h *= ampl + self.ax.set_xlim(center[0] - w / 2, center[0] + w / 2) self.ax.set_ylim(center[1] - h / 2, center[1] + h / 2) self.ax.imshow(self.image) @@ -629,41 +563,18 @@ def _show(self): self.ax.invert_yaxis() self.lasso = LassoSelector(self.ax, onselect=self.on_select) - self.show() + self.canvas.draw_idle() - def clear(self, *args): - self.inds.clear() - self.segs.clear() - self.lines.set_segments(self.segs) - - def export(self, *args): - inds_flat = set(ind for pair in self.inds for ind in pair) - unconnected = [i for i in range(len(self.xy)) if i not in inds_flat] - if len(unconnected): - warnings.warn( - "You didn't connect all the bodyparts (which is fine!). This is just a note to let you know.", - stacklevel=2, - ) - self.cfg["skeleton"] = [tuple(self.bpts[list(pair)]) for pair in self.inds] - auxiliaryfunctions.write_config(self.config_path, self.cfg) - - def on_pick(self, event): - if event.mouseevent.button == 3: - removed = event.artist.get_segments().pop(event.ind[0]) - self.segs.remove(tuple(map(tuple, removed))) - self.inds.remove(tuple(self.tree.query(removed)[1])) - - def on_select(self, verts): - self.path = Path(verts) - self.verts = verts - inds = self.tree.query_ball_point(verts, 5) - inds_unique = [] - for lst in inds: - if len(lst) and lst[0] not in inds_unique: - inds_unique.append(lst[0]) - for pair in zip(inds_unique, inds_unique[1:], strict=False): - pair_sorted = tuple(sorted(pair)) - self.inds.add(pair_sorted) - self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :]))) - self.lines.set_segments(self.segs) - self.fig.canvas.draw_idle() + def read_config(self, config_path): + return auxiliaryfunctions.read_config(config_path) + + def write_config(self, config_path, cfg): + # Normalize to plain lists before writing config.yaml + cfg = dict(cfg) + if "skeleton" in cfg: + cfg["skeleton"] = [list(pair) for pair in cfg["skeleton"]] + auxiliaryfunctions.write_config(config_path, cfg) + + def display(self): + # No-op, the dialog is shown/exec'd by the caller + pass diff --git a/deeplabcut/utils/skeleton.py b/deeplabcut/utils/skeleton.py index 4469cd7a1..6021b6f9e 100644 --- a/deeplabcut/utils/skeleton.py +++ b/deeplabcut/utils/skeleton.py @@ -26,23 +26,26 @@ import numpy as np import pandas as pd from matplotlib.collections import LineCollection -from matplotlib.path import Path from matplotlib.widgets import Button, LassoSelector from ruamel.yaml import YAML -from scipy.spatial import cKDTree as KDTree +from scipy.spatial import KDTree from skimage import io +# NOTE @C-Achard 2026-03-26 duplicate config read/write functions +# should be addressed in config refactor def read_config(configname): if not os.path.exists(configname): raise FileNotFoundError(f"Config {configname} is not found. Please make sure that the file exists.") - with open(configname) as file: - return YAML().load(file) + yaml = YAML(typ="rt") + with open(configname, encoding="utf-8") as file: + return yaml.load(file) def write_config(configname, cfg): - with open(configname, "w") as file: - YAML().dump(cfg, file) + yaml = YAML(typ="rt") + with open(configname, "w", encoding="utf-8") as file: + yaml.dump(cfg, file) class SkeletonBuilder: @@ -94,25 +97,10 @@ def __init__(self, config_path): self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :]))) self.lines = LineCollection(self.segs, colors=mcolors.to_rgba(self.cfg["skeleton_color"])) self.lines.set_picker(True) - self.show() + self.build_ui() + self.display() - def pick_labeled_frame(self): - # Find the most 'complete' animal - try: - count = self.df.groupby(level="individuals", axis=1).count() - if "single" in count: - count.drop("single", axis=1, inplace=True) - except KeyError: - count = self.df.count(axis=1).to_frame() - mask = count.where(count == count.values.max()) - kept = mask.stack().index.to_list() - np.random.shuffle(kept) - picked = kept.pop() - row = picked[:-1] - col = picked[-1] - return row, col - - def show(self): + def build_ui(self): self.fig = plt.figure() ax = self.fig.add_subplot(111) ax.axis("off") @@ -138,12 +126,37 @@ def show(self): self.export_button = Button(ax_export, "Export") self.export_button.on_clicked(self.export) self.fig.canvas.mpl_connect("pick_event", self.on_pick) + + def display(self): plt.show() + def pick_labeled_frame(self): + # Find the most 'complete' animal + if "individuals" in self.df.columns.names: + count = self.df.T.groupby(level="individuals").count().T + if "single" in count.columns: + count = count.drop(columns="single") + else: + count = self.df.count(axis=1).to_frame() + mask = count.where(count == count.to_numpy().max()) + kept = mask.stack().index.to_list() + np.random.shuffle(kept) + picked = kept.pop() + row = picked[:-1] + col = picked[-1] + return row, col + def clear(self, *args): self.inds.clear() self.segs.clear() - self.lines.set_segments(self.segs) + self.lines.set_segments([]) + self.fig.canvas.draw_idle() + + def read_config(self, config_path): + return read_config(config_path) + + def write_config(self, config_path, cfg): + write_config(config_path, cfg) def export(self, *args): inds_flat = set(ind for pair in self.inds for ind in pair) @@ -153,18 +166,24 @@ def export(self, *args): "You didn't connect all the bodyparts (which is fine!). This is just a note to let you know.", stacklevel=2, ) - self.cfg["skeleton"] = [tuple(self.bpts[list(pair)]) for pair in self.inds] - write_config(self.config_path, self.cfg) + # sort to ensure consistent order in config.yaml + self.cfg["skeleton"] = [tuple(self.bpts[list(pair)]) for pair in sorted(self.inds)] + self.write_config(self.config_path, self.cfg) def on_pick(self, event): if event.mouseevent.button == 3: - removed = event.artist.get_segments().pop(event.ind[0]) - self.segs.remove(tuple(map(tuple, removed))) - self.inds.remove(tuple(self.tree.query(removed)[1])) + seg = tuple(map(tuple, event.artist.get_segments()[event.ind[0]])) + self.segs.discard(seg) + + pair = tuple(sorted(self.tree.query(np.asarray(seg))[1])) + self.inds.discard(pair) + + self.lines.set_segments(list(self.segs)) + self.fig.canvas.draw_idle() def on_select(self, verts): - self.path = Path(verts) - self.verts = verts + # self.path = Path(verts) + # self.verts = verts inds = self.tree.query_ball_point(verts, 5) inds_unique = [] for lst in inds: diff --git a/tests/utils/test_skeleton.py b/tests/utils/test_skeleton.py new file mode 100644 index 000000000..c83d5060d --- /dev/null +++ b/tests/utils/test_skeleton.py @@ -0,0 +1,374 @@ +import warnings +from types import SimpleNamespace + +import matplotlib + +matplotlib.use("Agg", force=True) + +import numpy as np +import pandas as pd +import pytest +from matplotlib.collections import LineCollection +from matplotlib.figure import Figure +from scipy.spatial import KDTree + +from deeplabcut.utils import skeleton as skeleton_mod +from deeplabcut.utils.skeleton import SkeletonBuilder, write_config + +# --------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------- + + +def make_config(project_path, scorer="TestScorer", skeleton=None): + return { + "project_path": str(project_path), + "scorer": scorer, + "skeleton": skeleton or [], + "skeleton_color": "red", + "dotsize": 4, + } + + +def make_test_builder(): + """ + Construct a SkeletonBuilder instance without calling __init__, + so individual methods can be unit-tested in isolation. + """ + builder = SkeletonBuilder.__new__(SkeletonBuilder) + return builder + + +def attach_fake_canvas(builder): + builder.fig = Figure() + builder.fig.canvas.draw_idle = lambda: None + + +# --------------------------------------------------------------------- +# pick_labeled_frame +# --------------------------------------------------------------------- + + +def test_pick_labeled_frame_multi_animal_drops_single(monkeypatch): + builder = make_test_builder() + + index = pd.MultiIndex.from_tuples( + [("labeled-data/session1", "img001.png")], + names=["folder", "image"], + ) + columns = pd.MultiIndex.from_product( + [["TestScorer"], ["single", "mouseA"], ["nose", "tail"], ["x", "y"]], + names=["scorer", "individuals", "bodyparts", "coords"], + ) + + # "single" is fully labeled too, but should be dropped before choosing. + row = [ + 1.0, + 2.0, + 3.0, + 4.0, # single + 10.0, + 20.0, + 30.0, + 40.0, # mouseA + ] + builder.df = pd.DataFrame([row], index=index, columns=columns) + + monkeypatch.setattr(np.random, "shuffle", lambda x: None) + + picked_row, picked_col = builder.pick_labeled_frame() + + assert picked_row == ("labeled-data/session1", "img001.png") + assert picked_col == "mouseA" + + +def test_pick_labeled_frame_without_individuals(monkeypatch): + builder = make_test_builder() + + index = pd.MultiIndex.from_tuples( + [("labeled-data/session1", "img001.png")], + names=["folder", "image"], + ) + columns = pd.MultiIndex.from_product( + [["TestScorer"], ["nose", "tail"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + + builder.df = pd.DataFrame( + [[1.0, 2.0, 3.0, 4.0]], + index=index, + columns=columns, + ) + + monkeypatch.setattr(np.random, "shuffle", lambda x: None) + + picked_row, picked_col = builder.pick_labeled_frame() + + assert picked_row == ("labeled-data/session1", "img001.png") + # fallback path uses count(...).to_frame(), so the single column is usually 0 + assert picked_col == 0 + + +# --------------------------------------------------------------------- +# clear +# --------------------------------------------------------------------- + + +def test_clear_resets_indices_segments_and_linecollection(): + builder = make_test_builder() + builder.inds = {(0, 1), (1, 2)} + builder.segs = { + ((0.0, 0.0), (10.0, 0.0)), + ((10.0, 0.0), (20.0, 0.0)), + } + builder.lines = LineCollection([np.array([[0.0, 0.0], [10.0, 0.0]]), np.array([[10.0, 0.0], [20.0, 0.0]])]) + attach_fake_canvas(builder) + + builder.clear() + + assert builder.inds == set() + assert builder.segs == set() + assert list(builder.lines.get_segments()) == [] + + +# --------------------------------------------------------------------- +# export +# --------------------------------------------------------------------- + + +def test_export_sorts_pairs_and_warns_for_unconnected(monkeypatch): + builder = make_test_builder() + builder.config_path = "dummy_config.yaml" + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + [20.0, 0.0], + [30.0, 0.0], # intentionally left unconnected + ] + ) + builder.bpts = pd.Index(["nose", "tail", "paw", "ear"], name="bodyparts") + builder.inds = {(1, 2), (0, 1)} # intentionally unordered + builder.cfg = {"skeleton": []} + + captured = {} + + def fake_write_config(path, cfg): + captured["path"] = path + captured["cfg"] = cfg.copy() + + monkeypatch.setattr(skeleton_mod, "write_config", fake_write_config) + + with pytest.warns(UserWarning, match="didn't connect all the bodyparts"): + builder.export() + + assert captured["path"] == "dummy_config.yaml" + assert captured["cfg"]["skeleton"] == [ + ("nose", "tail"), + ("tail", "paw"), + ] + + +def test_export_without_warning_when_all_bodyparts_connected(monkeypatch): + builder = make_test_builder() + builder.config_path = "dummy_config.yaml" + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + [20.0, 0.0], + ] + ) + builder.bpts = pd.Index(["nose", "tail", "paw"], name="bodyparts") + builder.inds = {(0, 1), (1, 2)} + builder.cfg = {"skeleton": []} + + monkeypatch.setattr(skeleton_mod, "write_config", lambda path, cfg: None) + + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + builder.export() + + assert not any("didn't connect all the bodyparts" in str(w.message) for w in record) + assert builder.cfg["skeleton"] == [ + ("nose", "tail"), + ("tail", "paw"), + ] + + +# --------------------------------------------------------------------- +# on_select +# --------------------------------------------------------------------- + + +def test_on_select_adds_pairs_segments_and_updates_canvas(): + builder = make_test_builder() + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + [20.0, 0.0], + ] + ) + builder.tree = KDTree(builder.xy) + builder.inds = set() + builder.segs = set() + builder.lines = LineCollection([]) + attach_fake_canvas(builder) + + verts = [(0.0, 0.0), (10.0, 0.0), (20.0, 0.0)] + builder.on_select(verts) + + assert builder.inds == {(0, 1), (1, 2)} + assert ((0.0, 0.0), (10.0, 0.0)) in builder.segs + assert ((10.0, 0.0), (20.0, 0.0)) in builder.segs + assert len(builder.lines.get_segments()) == 2 + + +def test_on_select_ignores_duplicate_hits(): + builder = make_test_builder() + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + [20.0, 0.0], + ] + ) + builder.tree = KDTree(builder.xy) + builder.inds = set() + builder.segs = set() + builder.lines = LineCollection([]) + attach_fake_canvas(builder) + + # Repeated nearby vertices should not create duplicate pairs + verts = [(0.0, 0.0), (0.1, 0.0), (10.0, 0.0), (10.1, 0.0), (20.0, 0.0)] + builder.on_select(verts) + + assert builder.inds == {(0, 1), (1, 2)} + assert len(builder.segs) == 2 + + +# --------------------------------------------------------------------- +# on_pick +# --------------------------------------------------------------------- + + +def test_on_pick_right_click_removes_segment_and_pair(): + builder = make_test_builder() + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + ] + ) + builder.tree = KDTree(builder.xy) + builder.inds = {(0, 1)} + builder.segs = {((0.0, 0.0), (10.0, 0.0))} + builder.lines = LineCollection([np.array([[0.0, 0.0], [10.0, 0.0]])]) + attach_fake_canvas(builder) + + event = SimpleNamespace( + mouseevent=SimpleNamespace(button=3), + artist=builder.lines, + ind=[0], + ) + + builder.on_pick(event) + + assert builder.inds == set() + assert builder.segs == set() + assert list(builder.lines.get_segments()) == [] + + +def test_on_pick_non_right_click_does_nothing(): + builder = make_test_builder() + builder.xy = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], + ] + ) + builder.tree = KDTree(builder.xy) + builder.inds = {(0, 1)} + builder.segs = {((0.0, 0.0), (10.0, 0.0))} + builder.lines = LineCollection([np.array([[0.0, 0.0], [10.0, 0.0]])]) + attach_fake_canvas(builder) + + event = SimpleNamespace( + mouseevent=SimpleNamespace(button=1), + artist=builder.lines, + ind=[0], + ) + + builder.on_pick(event) + + assert builder.inds == {(0, 1)} + assert builder.segs == {((0.0, 0.0), (10.0, 0.0))} + assert len(builder.lines.get_segments()) == 1 + + +# --------------------------------------------------------------------- +# __init__ lightweight integration +# --------------------------------------------------------------------- + + +def test_init_loads_dataframe_image_and_existing_skeleton(tmp_path, monkeypatch): + project_path = tmp_path / "project" + labeled_data = project_path / "labeled-data" / "session1" + labeled_data.mkdir(parents=True) + + cfg_path = project_path / "config.yaml" + cfg = make_config( + project_path=project_path, + scorer="TestScorer", + skeleton=[ + ["nose", "tail"], + ["missing", "nose"], + ], # second pair should be ignored + ) + write_config(cfg_path, cfg) + + index = pd.MultiIndex.from_tuples( + [("labeled-data/session1", "img001.png")], + names=["folder", "image"], + ) + columns = pd.MultiIndex.from_product( + [["TestScorer"], ["nose", "tail"], ["x", "y"]], + names=["scorer", "bodyparts", "coords"], + ) + df = pd.DataFrame( + [[0.0, 0.0, 10.0, 0.0]], + index=index, + columns=columns, + ) + h5_path = labeled_data / "CollectedData_TestScorer.h5" + df.to_hdf(h5_path, key="df", mode="w") + + monkeypatch.setattr(skeleton_mod.io, "imread", lambda path: np.zeros((5, 5, 3), dtype=np.uint8)) + monkeypatch.setattr(SkeletonBuilder, "build_ui", lambda self: None) + monkeypatch.setattr(SkeletonBuilder, "display", lambda self: None) + monkeypatch.setattr(np.random, "shuffle", lambda x: None) + + builder = SkeletonBuilder(str(cfg_path)) + + assert builder.config_path == str(cfg_path) + assert list(builder.bpts) == ["nose", "tail"] + assert builder.xy.shape == (2, 2) + assert builder.image.shape == (5, 5, 3) + assert builder.inds == {(0, 1)} + assert ((0.0, 0.0), (10.0, 0.0)) in builder.segs + + +def test_init_raises_if_no_labeled_data_found(tmp_path, monkeypatch): + project_path = tmp_path / "project" + (project_path / "labeled-data").mkdir(parents=True) + + cfg_path = project_path / "config.yaml" + cfg = make_config(project_path=project_path, scorer="TestScorer") + write_config(cfg_path, cfg) + + monkeypatch.setattr(SkeletonBuilder, "build_ui", lambda self: None) + monkeypatch.setattr(SkeletonBuilder, "display", lambda self: None) + + with pytest.raises(IOError, match="No labeled data were found"): + SkeletonBuilder(str(cfg_path))