1- from typing import List , Tuple
1+ from typing import List , Optional , Tuple
22
33import numpy as np
44
1111class STrack (BaseTrack ):
1212 shared_kalman = KalmanFilter ()
1313
14- def __init__ (self , tlwh , score , class_ids ):
14+ def __init__ (self , tlwh , score , class_ids , mask : Optional [ np . array ] = None ):
1515 # wait activate
1616 self ._tlwh = np .asarray (tlwh , dtype = np .float32 )
1717 self .kalman_filter = None
@@ -21,6 +21,7 @@ def __init__(self, tlwh, score, class_ids):
2121 self .score = score
2222 self .class_ids = class_ids
2323 self .tracklet_len = 0
24+ self .mask = mask
2425
2526 def predict (self ):
2627 mean_state = self .mean .copy ()
@@ -74,6 +75,7 @@ def re_activate(self, new_track, frame_id, new_id=False):
7475 if new_id :
7576 self .track_id = self .next_id ()
7677 self .score = new_track .score
78+ self .mask = new_track .mask
7779
7880 def update (self , new_track , frame_id ):
7981 """
@@ -95,6 +97,8 @@ def update(self, new_track, frame_id):
9597
9698 self .score = new_track .score
9799
100+ self .mask = new_track .mask
101+
98102 @property
99103 def tlwh (self ):
100104 """Get current position in bounding box format `(top left x, top left y,
@@ -231,9 +235,8 @@ def update_with_detections(self, detections: Detections) -> Detections:
231235 ... )
232236 ```
233237 """
234-
235238 tracks = self .update_with_tensors (
236- tensors = detections2boxes (detections = detections )
239+ tensors = detections2boxes (detections = detections ), masks = detections . mask
237240 )
238241 detections = Detections .empty ()
239242 if len (tracks ) > 0 :
@@ -249,17 +252,22 @@ def update_with_detections(self, detections: Detections) -> Detections:
249252 detections .confidence = np .array (
250253 [t .score for t in tracks ], dtype = np .float32
251254 )
255+ detections .mask = np .array ([t .mask for t in tracks ], dtype = bool )
256+
252257 else :
253258 detections .tracker_id = np .array ([], dtype = int )
254259
255260 return detections
256261
257- def update_with_tensors (self , tensors : np .ndarray ) -> List [STrack ]:
262+ def update_with_tensors (
263+ self , tensors : np .ndarray , masks : Optional [np .array ] = None
264+ ) -> List [STrack ]:
258265 """
259266 Updates the tracker with the provided tensors and returns the updated tracks.
260267
261268 Parameters:
262269 tensors: The new tensors to update with.
270+ masks: The new masks associated to new tensors
263271
264272 Returns:
265273 List[STrack]: Updated tracks.
@@ -281,6 +289,12 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
281289 inds_second = np .logical_and (inds_low , inds_high )
282290 dets_second = bboxes [inds_second ]
283291 dets = bboxes [remain_inds ]
292+ if masks is not None :
293+ masks_keep = masks [remain_inds ]
294+ masks_second = masks [inds_second ]
295+ else :
296+ masks_keep = np .array ([None ] * len (remain_inds ))
297+ masks_second = np .array ([None ] * len (inds_second ))
284298 scores_keep = scores [remain_inds ]
285299 scores_second = scores [inds_second ]
286300
@@ -290,8 +304,10 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
290304 if len (dets ) > 0 :
291305 """Detections"""
292306 detections = [
293- STrack (STrack .tlbr_to_tlwh (tlbr ), s , c )
294- for (tlbr , s , c ) in zip (dets , scores_keep , class_ids_keep )
307+ STrack (STrack .tlbr_to_tlwh (tlbr ), s , c , m )
308+ for (tlbr , s , c , m ) in zip (
309+ dets , scores_keep , class_ids_keep , masks_keep
310+ )
295311 ]
296312 else :
297313 detections = []
@@ -331,8 +347,10 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
331347 if len (dets_second ) > 0 :
332348 """Detections"""
333349 detections_second = [
334- STrack (STrack .tlbr_to_tlwh (tlbr ), s , c )
335- for (tlbr , s , c ) in zip (dets_second , scores_second , class_ids_second )
350+ STrack (STrack .tlbr_to_tlwh (tlbr ), s , c , m )
351+ for (tlbr , s , c , m ) in zip (
352+ dets_second , scores_second , class_ids_second , masks_second
353+ )
336354 ]
337355 else :
338356 detections_second = []
0 commit comments