Skip to content

Commit 8eb827b

Browse files
authored
Merge pull request roboflow#437 from ashishdatta/from_roboflow_tracker_id
Update sv.Detections.from_roboflow to extract tracker_id values from inference response
2 parents 33e43bc + 6ef7bf3 commit 8eb827b

3 files changed

Lines changed: 27 additions & 8 deletions

File tree

supervision/detection/core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def from_roboflow(cls, roboflow_result: dict) -> Detections:
426426
>>> detections = sv.Detections.from_roboflow(roboflow_result)
427427
```
428428
"""
429-
xyxy, confidence, class_id, masks = process_roboflow_result(
429+
xyxy, confidence, class_id, masks, trackers = process_roboflow_result(
430430
roboflow_result=roboflow_result
431431
)
432432

@@ -438,6 +438,7 @@ def from_roboflow(cls, roboflow_result: dict) -> Detections:
438438
confidence=confidence,
439439
class_id=class_id,
440440
mask=masks,
441+
tracker_id=trackers,
441442
)
442443

443444
@classmethod

supervision/detection/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,15 @@ def extract_ultralytics_masks(yolov8_results) -> Optional[np.ndarray]:
333333

334334
def process_roboflow_result(
335335
roboflow_result: dict,
336-
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
336+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray]:
337337
if not roboflow_result["predictions"]:
338-
return np.empty((0, 4)), np.empty(0), np.empty(0), None
338+
return np.empty((0, 4)), np.empty(0), np.empty(0), None, None
339339

340340
xyxy = []
341341
confidence = []
342342
class_id = []
343343
masks = []
344+
tracker_ids = []
344345

345346
image_width = int(roboflow_result["image"]["width"])
346347
image_height = int(roboflow_result["image"]["height"])
@@ -359,6 +360,8 @@ def process_roboflow_result(
359360
xyxy.append([x_min, y_min, x_max, y_max])
360361
class_id.append(prediction["class_id"])
361362
confidence.append(prediction["confidence"])
363+
if "tracker_id" in prediction:
364+
tracker_ids.append(prediction["tracker_id"])
362365
elif len(prediction["points"]) >= 3:
363366
polygon = np.array(
364367
[[point["x"], point["y"]] for point in prediction["points"]], dtype=int
@@ -368,13 +371,16 @@ def process_roboflow_result(
368371
class_id.append(prediction["class_id"])
369372
confidence.append(prediction["confidence"])
370373
masks.append(mask)
374+
if "tracker_id" in prediction:
375+
tracker_ids.append(prediction["tracker_id"])
371376

372377
xyxy = np.array(xyxy) if len(xyxy) > 0 else np.empty((0, 4))
373378
confidence = np.array(confidence) if len(confidence) > 0 else np.empty(0)
374379
class_id = np.array(class_id).astype(int) if len(class_id) > 0 else np.empty(0)
375380
masks = np.array(masks, dtype=bool) if len(masks) > 0 else None
381+
tracker_id = np.array(tracker_ids).astype(int) if len(tracker_ids) > 0 else None
376382

377-
return xyxy, confidence, class_id, masks
383+
return xyxy, confidence, class_id, masks, tracker_id
378384

379385

380386
def move_boxes(xyxy: np.ndarray, offset: np.ndarray) -> np.ndarray:

test/detection/test_utils.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def test_filter_polygons_by_area(
264264
[
265265
(
266266
{"predictions": [], "image": {"width": 1000, "height": 1000}},
267-
(np.empty((0, 4)), np.empty(0), np.empty(0), None),
267+
(np.empty((0, 4)), np.empty(0), np.empty(0), None, None),
268268
DoesNotRaise(),
269269
), # empty result
270270
(
@@ -287,6 +287,7 @@ def test_filter_polygons_by_area(
287287
np.array([0.9]),
288288
np.array([0]),
289289
None,
290+
None,
290291
),
291292
DoesNotRaise(),
292293
), # single correct object detection result
@@ -301,6 +302,7 @@ def test_filter_polygons_by_area(
301302
"confidence": 0.9,
302303
"class_id": 0,
303304
"class": "person",
305+
"tracker_id": 1,
304306
},
305307
{
306308
"x": 500.0,
@@ -310,6 +312,7 @@ def test_filter_polygons_by_area(
310312
"confidence": 0.8,
311313
"class_id": 7,
312314
"class": "truck",
315+
"tracker_id": 2,
313316
},
314317
],
315318
"image": {"width": 1000, "height": 1000},
@@ -319,6 +322,7 @@ def test_filter_polygons_by_area(
319322
np.array([0.9, 0.8]),
320323
np.array([0, 7]),
321324
None,
325+
np.array([1, 2]),
322326
),
323327
DoesNotRaise(),
324328
), # two correct object detection result
@@ -334,11 +338,12 @@ def test_filter_polygons_by_area(
334338
"class_id": 0,
335339
"class": "person",
336340
"points": [],
341+
"tracker_id": None,
337342
}
338343
],
339344
"image": {"width": 1000, "height": 1000},
340345
},
341-
(np.empty((0, 4)), np.empty(0), np.empty(0), None),
346+
(np.empty((0, 4)), np.empty(0), np.empty(0), None, None),
342347
DoesNotRaise(),
343348
), # single incorrect instance segmentation result with no points
344349
(
@@ -357,7 +362,7 @@ def test_filter_polygons_by_area(
357362
],
358363
"image": {"width": 1000, "height": 1000},
359364
},
360-
(np.empty((0, 4)), np.empty(0), np.empty(0), None),
365+
(np.empty((0, 4)), np.empty(0), np.empty(0), None, None),
361366
DoesNotRaise(),
362367
), # single incorrect instance segmentation result with no enough points
363368
(
@@ -386,6 +391,7 @@ def test_filter_polygons_by_area(
386391
np.array([0.9]),
387392
np.array([0]),
388393
TEST_MASK,
394+
None,
389395
),
390396
DoesNotRaise(),
391397
), # single incorrect instance segmentation result with no enough points
@@ -425,14 +431,17 @@ def test_filter_polygons_by_area(
425431
np.array([0.9]),
426432
np.array([0]),
427433
TEST_MASK,
434+
None,
428435
),
429436
DoesNotRaise(),
430437
), # two instance segmentation results - one correct, one incorrect
431438
],
432439
)
433440
def test_process_roboflow_result(
434441
roboflow_result: dict,
435-
expected_result: Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]],
442+
expected_result: Tuple[
443+
np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray
444+
],
436445
exception: Exception,
437446
) -> None:
438447
with exception:
@@ -443,6 +452,9 @@ def test_process_roboflow_result(
443452
assert (result[3] is None and expected_result[3] is None) or (
444453
np.array_equal(result[3], expected_result[3])
445454
)
455+
assert (result[4] is None and expected_result[4] is None) or (
456+
np.array_equal(result[4], expected_result[4])
457+
)
446458

447459

448460
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)