ultralytics 8.0.81
single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -6,6 +6,8 @@ import numpy as np
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""Enumeration of possible object tracking states."""
|
||||
|
||||
New = 0
|
||||
Tracked = 1
|
||||
Lost = 2
|
||||
@ -13,6 +15,8 @@ class TrackState:
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
"""Base class for object tracking, handling basic track attributes and operations."""
|
||||
|
||||
_count = 0
|
||||
|
||||
track_id = 0
|
||||
@ -32,28 +36,36 @@ class BaseTrack:
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
"""Return the last frame ID of the track."""
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
"""Increment and return the global track ID counter."""
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""Activate the track with the provided arguments."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""Predict the next state of the track."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Update the track with new observations."""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
"""Mark the track as lost."""
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
"""Mark the track as removed."""
|
||||
self.state = TrackState.Removed
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Reset the global track ID counter."""
|
||||
BaseTrack._count = 0
|
||||
|
@ -15,6 +15,7 @@ class BOTrack(STrack):
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
@ -25,6 +26,7 @@ class BOTrack(STrack):
|
||||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
"""Update features vector and smooth it using exponential moving average."""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
@ -35,6 +37,7 @@ class BOTrack(STrack):
|
||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""Predicts the mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
@ -43,11 +46,13 @@ class BOTrack(STrack):
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a track with updated features and optionally assigns a new ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
@ -65,6 +70,7 @@ class BOTrack(STrack):
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
@ -79,6 +85,7 @@ class BOTrack(STrack):
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
@ -94,6 +101,7 @@ class BOTrack(STrack):
|
||||
class BOTSORT(BYTETracker):
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
@ -106,9 +114,11 @@ class BOTSORT(BYTETracker):
|
||||
self.gmc = GMC(method=args.cmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
@ -118,6 +128,7 @@ class BOTSORT(BYTETracker):
|
||||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = (dists > self.proximity_thresh)
|
||||
|
||||
@ -133,4 +144,5 @@ class BOTSORT(BYTETracker):
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
BOTrack.multi_predict(tracks)
|
||||
|
@ -23,6 +23,7 @@ class STrack(BaseTrack):
|
||||
self.idx = tlwh[-1]
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
@ -30,6 +31,7 @@ class STrack(BaseTrack):
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
@ -44,6 +46,7 @@ class STrack(BaseTrack):
|
||||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
"""Update state tracks positions and covariances using a homography matrix."""
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
@ -74,6 +77,7 @@ class STrack(BaseTrack):
|
||||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_track.tlwh))
|
||||
self.tracklet_len = 0
|
||||
@ -107,6 +111,7 @@ class STrack(BaseTrack):
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
@ -142,23 +147,27 @@ class STrack(BaseTrack):
|
||||
|
||||
@staticmethod
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
"""Converts top-left bottom-right format to top-left width height format."""
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
"""Converts tlwh bounding box format to tlbr format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
|
||||
|
||||
|
||||
class BYTETracker:
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
@ -170,6 +179,7 @@ class BYTETracker:
|
||||
self.reset_id()
|
||||
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
self.frame_id += 1
|
||||
activated_starcks = []
|
||||
refind_stracks = []
|
||||
@ -285,12 +295,15 @@ class BYTETracker:
|
||||
dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IOU and fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
@ -298,13 +311,16 @@ class BYTETracker:
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
def reset_id(self):
|
||||
"""Resets the ID counter of STrack."""
|
||||
STrack.reset_id()
|
||||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
@ -332,6 +348,7 @@ class BYTETracker:
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IOU distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
|
Reference in New Issue
Block a user