Skip to content

Commit 8bca45b

Browse files
style and doc change as requested
1 parent a550aaa commit 8bca45b

1 file changed

Lines changed: 10 additions & 9 deletions

File tree

  • supervision/tracker/byte_tracker

supervision/tracker/byte_tracker/core.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple
1+
from typing import List, Tuple, Optional
22

33
import numpy as np
44

@@ -11,7 +11,7 @@
1111
class STrack(BaseTrack):
1212
shared_kalman = KalmanFilter()
1313

14-
def __init__(self, tlwh, score, class_ids, mask=None):
14+
def __init__(self, tlwh, score, class_ids, mask:Optional[np.array] = None):
1515
# wait activate
1616
self._tlwh = np.asarray(tlwh, dtype=np.float32)
1717
self.kalman_filter = None
@@ -259,13 +259,14 @@ def update_with_detections(self, detections: Detections) -> Detections:
259259
return detections
260260

261261
def update_with_tensors(
262-
self, tensors: np.ndarray, masks: np.ndarray = None
262+
self, tensors: np.ndarray, masks: Optional[np.array] = None
263263
) -> List[STrack]:
264264
"""
265265
Updates the tracker with the provided tensors and returns the updated tracks.
266266
267267
Parameters:
268268
tensors: The new tensors to update with.
269+
masks: The new masks associated to new tensors
269270
270271
Returns:
271272
List[STrack]: Updated tracks.
@@ -288,11 +289,11 @@ def update_with_tensors(
288289
dets_second = bboxes[inds_second]
289290
dets = bboxes[remain_inds]
290291
if masks is not None:
291-
masks_hs = masks[remain_inds]
292-
masks_ls = masks[inds_second]
292+
masks_keep = masks[remain_inds]
293+
masks_second = masks[inds_second]
293294
else:
294-
masks_hs = np.array([None] * len(remain_inds))
295-
masks_ls = np.array([None] * len(inds_second))
295+
masks_keep = np.array([None] * len(remain_inds))
296+
masks_second = np.array([None] * len(inds_second))
296297
scores_keep = scores[remain_inds]
297298
scores_second = scores[inds_second]
298299

@@ -303,7 +304,7 @@ def update_with_tensors(
303304
"""Detections"""
304305
detections = [
305306
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, m)
306-
for (tlbr, s, c, m) in zip(dets, scores_keep, class_ids_keep, masks_hs)
307+
for (tlbr, s, c, m) in zip(dets, scores_keep, class_ids_keep, masks_keep)
307308
]
308309
else:
309310
detections = []
@@ -345,7 +346,7 @@ def update_with_tensors(
345346
detections_second = [
346347
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, m)
347348
for (tlbr, s, c, m) in zip(
348-
dets_second, scores_second, class_ids_second, masks_ls
349+
dets_second, scores_second, class_ids_second, masks_second
349350
)
350351
]
351352
else:

0 commit comments

Comments
 (0)