From 75ecd483cd53368aa59fed2fa501960418933263 Mon Sep 17 00:00:00 2001 From: Jaap de Ruyter van Steveninck <32810691+deruyter92@users.noreply.github.com> Date: Mon, 22 Jun 2026 10:43:39 +0200 Subject: [PATCH] Fix NaN handling in PredictKeypointIdentities postprocessor --- .../data/postprocessor.py | 12 ++++-- .../data/test_postprocessor.py | 42 +++++++++++++++++++ 2 files changed, 50 insertions(+), 4 deletions(-) diff --git a/deeplabcut/pose_estimation_pytorch/data/postprocessor.py b/deeplabcut/pose_estimation_pytorch/data/postprocessor.py index e408b9b49..0af064e8f 100644 --- a/deeplabcut/pose_estimation_pytorch/data/postprocessor.py +++ b/deeplabcut/pose_estimation_pytorch/data/postprocessor.py @@ -504,13 +504,17 @@ def __call__(self, predictions: dict[str, np.ndarray], context: Context) -> tupl id_score_matrix = np.zeros((num_preds, num_keypoints, num_ids)) for pred_idx, individual_keypoints in enumerate(pose): - heatmap_indices = np.rint(individual_keypoints).astype(int) + xy = individual_keypoints[:, :2] + valid = np.all(np.isfinite(xy), axis=1) + + heatmap_indices = np.zeros((num_keypoints, 2), dtype=int) + if np.any(valid): + heatmap_indices[valid] = np.rint(xy[valid]).astype(int) xs = np.clip(heatmap_indices[:, 0], 0, w - 1) ys = np.clip(heatmap_indices[:, 1], 0, h - 1) - - # get the score from each identity heatmap at each predicted keypoint for kpt_idx, (x, y) in enumerate(zip(xs, ys, strict=False)): - id_score_matrix[pred_idx, kpt_idx] = identity_heatmap[y, x, :] + if valid[kpt_idx]: + id_score_matrix[pred_idx, kpt_idx] = identity_heatmap[y, x, :] predictions[self.identity_key] = id_score_matrix if not self.keep_id_maps: diff --git a/tests/pose_estimation_pytorch/data/test_postprocessor.py b/tests/pose_estimation_pytorch/data/test_postprocessor.py index 7148fdf1c..354f63348 100644 --- a/tests/pose_estimation_pytorch/data/test_postprocessor.py +++ b/tests/pose_estimation_pytorch/data/test_postprocessor.py @@ -379,3 +379,45 @@ def test_remove_low_confidence_boxes(data): np.testing.assert_array_equal(predictions["bboxes"], expected_bboxes) np.testing.assert_array_equal(predictions["bbox_scores"], expected_scores) + + +def test_predict_keypoint_identities_handles_nan_keypoints(): + import warnings + + p = PredictKeypointIdentities( + identity_key="keypoint_identity", + identity_map_key="identity_map", + pose_key="bodyparts", + keep_id_maps=True, + ) + + # PAF-style output: (num_individuals, num_bodyparts, 5); missing joint is all-NaN + bodyparts = np.array( + [ + [ + [3.1, 1.0, 0.8, 0.0, 0.5], # valid + [np.nan, np.nan, np.nan, np.nan, np.nan], # missing (assembler default) + [1.0, 0.0, 0.9, 1.0, 0.5], # valid + ], + ] + ) + id_heatmap = np.array( + [ + [[0.1, 0.1], [0.2, 0.1], [0.3, 0.1], [0.4, 0.1]], + [[0.1, 0.2], [0.2, 0.2], [0.3, 0.2], [0.4, 0.2]], + [[0.1, 0.3], [0.2, 0.3], [0.3, 0.3], [0.4, 0.3]], + [[0.1, 0.4], [0.2, 0.4], [0.3, 0.4], [0.4, 0.4]], + ] + ) + predictions_in = {"bodyparts": bodyparts, "identity_map": id_heatmap} + + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + predictions, _ = p(predictions_in, {}) + + expected = np.zeros((1, 3, 2)) + expected[0, 0] = id_heatmap[1, 3] # rint(3.1, 1.0) -> (3, 1) + expected[0, 1] = 0.0 # NaN keypoint: leave identity scores at zero + expected[0, 2] = id_heatmap[0, 1] # rint(1.0, 0.0) -> (1, 0) + + np.testing.assert_array_equal(predictions["keypoint_identity"], expected)