Skip to content

Commit 48851b5

Browse files
authored
Merge pull request roboflow#436 from AntonioConsiglio/develop
roboflow#418 Make sv.ByteTrack work with segmentation model
2 parents e5894df + 49d268c commit 48851b5

File tree

1 file changed

+27
-9
lines changed
  • supervision/tracker/byte_tracker

1 file changed

+27
-9
lines changed

supervision/tracker/byte_tracker/core.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import List, Optional, Tuple
22

33
import numpy as np
44

@@ -11,7 +11,7 @@
1111
class STrack(BaseTrack):
1212
shared_kalman = KalmanFilter()
1313

14-
def __init__(self, tlwh, score, class_ids):
14+
def __init__(self, tlwh, score, class_ids, mask: Optional[np.array] = None):
1515
# wait activate
1616
self._tlwh = np.asarray(tlwh, dtype=np.float32)
1717
self.kalman_filter = None
@@ -21,6 +21,7 @@ def __init__(self, tlwh, score, class_ids):
2121
self.score = score
2222
self.class_ids = class_ids
2323
self.tracklet_len = 0
24+
self.mask = mask
2425

2526
def predict(self):
2627
mean_state = self.mean.copy()
@@ -74,6 +75,7 @@ def re_activate(self, new_track, frame_id, new_id=False):
7475
if new_id:
7576
self.track_id = self.next_id()
7677
self.score = new_track.score
78+
self.mask = new_track.mask
7779

7880
def update(self, new_track, frame_id):
7981
"""
@@ -95,6 +97,8 @@ def update(self, new_track, frame_id):
9597

9698
self.score = new_track.score
9799

100+
self.mask = new_track.mask
101+
98102
@property
99103
def tlwh(self):
100104
"""Get current position in bounding box format `(top left x, top left y,
@@ -231,9 +235,8 @@ def update_with_detections(self, detections: Detections) -> Detections:
231235
... )
232236
```
233237
"""
234-
235238
tracks = self.update_with_tensors(
236-
tensors=detections2boxes(detections=detections)
239+
tensors=detections2boxes(detections=detections), masks=detections.mask
237240
)
238241
detections = Detections.empty()
239242
if len(tracks) > 0:
@@ -249,17 +252,22 @@ def update_with_detections(self, detections: Detections) -> Detections:
249252
detections.confidence = np.array(
250253
[t.score for t in tracks], dtype=np.float32
251254
)
255+
detections.mask = np.array([t.mask for t in tracks], dtype=bool)
256+
252257
else:
253258
detections.tracker_id = np.array([], dtype=int)
254259

255260
return detections
256261

257-
def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
262+
def update_with_tensors(
263+
self, tensors: np.ndarray, masks: Optional[np.array] = None
264+
) -> List[STrack]:
258265
"""
259266
Updates the tracker with the provided tensors and returns the updated tracks.
260267
261268
Parameters:
262269
tensors: The new tensors to update with.
270+
masks: The new masks associated to new tensors
263271
264272
Returns:
265273
List[STrack]: Updated tracks.
@@ -281,6 +289,12 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
281289
inds_second = np.logical_and(inds_low, inds_high)
282290
dets_second = bboxes[inds_second]
283291
dets = bboxes[remain_inds]
292+
if masks is not None:
293+
masks_keep = masks[remain_inds]
294+
masks_second = masks[inds_second]
295+
else:
296+
masks_keep = np.array([None] * len(remain_inds))
297+
masks_second = np.array([None] * len(inds_second))
284298
scores_keep = scores[remain_inds]
285299
scores_second = scores[inds_second]
286300

@@ -290,8 +304,10 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
290304
if len(dets) > 0:
291305
"""Detections"""
292306
detections = [
293-
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
294-
for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep)
307+
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, m)
308+
for (tlbr, s, c, m) in zip(
309+
dets, scores_keep, class_ids_keep, masks_keep
310+
)
295311
]
296312
else:
297313
detections = []
@@ -331,8 +347,10 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
331347
if len(dets_second) > 0:
332348
"""Detections"""
333349
detections_second = [
334-
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
335-
for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second)
350+
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, m)
351+
for (tlbr, s, c, m) in zip(
352+
dets_second, scores_second, class_ids_second, masks_second
353+
)
336354
]
337355
else:
338356
detections_second = []

0 commit comments

Comments
 (0)