ultralytics 8.0.136
refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
86
ultralytics/trackers/README.md
Normal file
86
ultralytics/trackers/README.md
Normal file
@ -0,0 +1,86 @@
|
||||
# Tracker
|
||||
|
||||
## Supported Trackers
|
||||
|
||||
- [x] ByteTracker
|
||||
- [x] BoT-SORT
|
||||
|
||||
## Usage
|
||||
|
||||
### python interface:
|
||||
|
||||
You can use the Python interface to track objects using the YOLO model.
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt
|
||||
model.track(
|
||||
source="video/streams",
|
||||
stream=True,
|
||||
tracker="botsort.yaml", # or 'bytetrack.yaml'
|
||||
show=True,
|
||||
)
|
||||
```
|
||||
|
||||
You can get the IDs of the tracked objects using the following code:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("yolov8n.pt")
|
||||
|
||||
for result in model.track(source="video.mp4"):
|
||||
print(
|
||||
result.boxes.id.cpu().numpy().astype(int)
|
||||
) # this will print the IDs of the tracked objects in the frame
|
||||
```
|
||||
|
||||
If you want to use the tracker with a folder of images or when you loop on the video frames, you should use the `persist` parameter to tell the model that these frames are related to each other so the IDs will be fixed for the same objects. Otherwise, the IDs will be different in each frame because in each loop, the model creates a new object for tracking, but the `persist` parameter makes it use the same object for tracking.
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
cap = cv2.VideoCapture("video.mp4")
|
||||
model = YOLO("yolov8n.pt")
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
results = model.track(frame, persist=True)
|
||||
boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
|
||||
ids = results[0].boxes.id.cpu().numpy().astype(int)
|
||||
for box, id in zip(boxes, ids):
|
||||
cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)
|
||||
cv2.putText(
|
||||
frame,
|
||||
f"Id {id}",
|
||||
(box[0], box[1]),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 255),
|
||||
2,
|
||||
)
|
||||
cv2.imshow("frame", frame)
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
```
|
||||
|
||||
## Change tracker parameters
|
||||
|
||||
You can change the tracker parameters by eding the `tracker.yaml` file which is located in the ultralytics/cfg/trackers folder.
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
You can also use the command line interface to track objects using the YOLO model.
|
||||
|
||||
```bash
|
||||
yolo detect track source=... tracker=...
|
||||
yolo segment track source=... tracker=...
|
||||
yolo pose track source=... tracker=...
|
||||
```
|
||||
|
||||
By default, trackers will use the configuration in `ultralytics/cfg/trackers`.
|
||||
We also support using a modified tracker config file. Please refer to the tracker config files
|
||||
in `ultralytics/cfg/trackers`.<br>
|
7
ultralytics/trackers/__init__.py
Normal file
7
ultralytics/trackers/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
from .track import register_tracker
|
||||
|
||||
__all__ = 'register_tracker', 'BOTSORT', 'BYTETracker' # allow simpler import
|
71
ultralytics/trackers/basetrack.py
Normal file
71
ultralytics/trackers/basetrack.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TrackState:
|
||||
"""Enumeration of possible object tracking states."""
|
||||
|
||||
New = 0
|
||||
Tracked = 1
|
||||
Lost = 2
|
||||
Removed = 3
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
"""Base class for object tracking, handling basic track attributes and operations."""
|
||||
|
||||
_count = 0
|
||||
|
||||
track_id = 0
|
||||
is_activated = False
|
||||
state = TrackState.New
|
||||
|
||||
history = OrderedDict()
|
||||
features = []
|
||||
curr_feature = None
|
||||
score = 0
|
||||
start_frame = 0
|
||||
frame_id = 0
|
||||
time_since_update = 0
|
||||
|
||||
# Multi-camera
|
||||
location = (np.inf, np.inf)
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
"""Return the last frame ID of the track."""
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
"""Increment and return the global track ID counter."""
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
"""Activate the track with the provided arguments."""
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
"""Predict the next state of the track."""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
"""Update the track with new observations."""
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
"""Mark the track as lost."""
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
"""Mark the track as removed."""
|
||||
self.state = TrackState.Removed
|
||||
|
||||
@staticmethod
|
||||
def reset_id():
|
||||
"""Reset the global track ID counter."""
|
||||
BaseTrack._count = 0
|
148
ultralytics/trackers/bot_sort.py
Normal file
148
ultralytics/trackers/bot_sort.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import TrackState
|
||||
from .byte_tracker import BYTETracker, STrack
|
||||
from .utils import matching
|
||||
from .utils.gmc import GMC
|
||||
from .utils.kalman_filter import KalmanFilterXYWH
|
||||
|
||||
|
||||
class BOTrack(STrack):
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
self.curr_feat = None
|
||||
if feat is not None:
|
||||
self.update_features(feat)
|
||||
self.features = deque([], maxlen=feat_history)
|
||||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
"""Update features vector and smooth it using exponential moving average."""
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
self.smooth_feat = feat
|
||||
else:
|
||||
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
|
||||
self.features.append(feat)
|
||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
"""Predicts the mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
mean_state[7] = 0
|
||||
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a track with updated features and optionally assigns a new ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""Update the YOLOv8 instance with new track and frame ID."""
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][6] = 0
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width,
|
||||
height)`.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
|
||||
class BOTSORT(BYTETracker):
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
self.appearance_thresh = args.appearance_thresh
|
||||
|
||||
if args.with_reid:
|
||||
# Haven't supported BoT-SORT(reid) yet
|
||||
self.encoder = None
|
||||
# self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation])
|
||||
self.gmc = GMC(method=args.cmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns an instance of KalmanFilterXYWH for object tracking."""
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize track with detections, scores, and classes."""
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
features_keep = self.encoder.inference(img, dets)
|
||||
return [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)] # detections
|
||||
else:
|
||||
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = (dists > self.proximity_thresh)
|
||||
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
emb_dists = matching.embedding_distance(tracks, detections) / 2.0
|
||||
emb_dists[emb_dists > self.appearance_thresh] = 1.0
|
||||
emb_dists[dists_mask] = 1.0
|
||||
dists = np.minimum(dists, emb_dists)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Predict and track multiple objects with YOLOv8 model."""
|
||||
BOTrack.multi_predict(tracks)
|
364
ultralytics/trackers/byte_tracker.py
Normal file
364
ultralytics/trackers/byte_tracker.py
Normal file
@ -0,0 +1,364 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
from .utils import matching
|
||||
from .utils.kalman_filter import KalmanFilterXYAH
|
||||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, tlwh, score, cls):
|
||||
"""wait activate."""
|
||||
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.cls = cls
|
||||
self.idx = tlwh[-1]
|
||||
|
||||
def predict(self):
|
||||
"""Predicts mean and covariance using Kalman filter."""
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
"""Update state tracks positions and covariances using a homography matrix."""
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
|
||||
R = H[:2, :2]
|
||||
R8x8 = np.kron(np.eye(4, dtype=float), R)
|
||||
t = H[:2, 2]
|
||||
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
mean = R8x8.dot(mean)
|
||||
mean[:2] += t
|
||||
cov = R8x8.dot(cov).dot(R8x8.transpose())
|
||||
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet."""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
if frame_id == 1:
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
"""Reactivates a previously lost track with a new detection."""
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_track.tlwh))
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
if new_id:
|
||||
self.track_id = self.next_id()
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.idx = new_track.idx
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""
|
||||
Update a matched track
|
||||
:type new_track: STrack
|
||||
:type frame_id: int
|
||||
:return:
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
||||
new_tlwh = new_track.tlwh
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_tlwh))
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
"""Converts top-left bottom-right format to top-left width height format."""
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
"""Converts tlwh bounding box format to tlbr format."""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
|
||||
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
|
||||
|
||||
|
||||
class BYTETracker:
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
||||
self.frame_id = 0
|
||||
self.args = args
|
||||
self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)
|
||||
self.kalman_filter = self.get_kalmanfilter()
|
||||
self.reset_id()
|
||||
|
||||
def update(self, results, img=None):
|
||||
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
|
||||
self.frame_id += 1
|
||||
activated_stracks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
removed_stracks = []
|
||||
|
||||
scores = results.conf
|
||||
bboxes = results.xyxy
|
||||
# Add index
|
||||
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
||||
cls = results.cls
|
||||
|
||||
remain_inds = scores > self.args.track_high_thresh
|
||||
inds_low = scores > self.args.track_low_thresh
|
||||
inds_high = scores < self.args.track_high_thresh
|
||||
|
||||
inds_second = np.logical_and(inds_low, inds_high)
|
||||
dets_second = bboxes[inds_second]
|
||||
dets = bboxes[remain_inds]
|
||||
scores_keep = scores[remain_inds]
|
||||
scores_second = scores[inds_second]
|
||||
cls_keep = cls[remain_inds]
|
||||
cls_second = cls[inds_second]
|
||||
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img)
|
||||
# Add newly detected tracklets to tracked_stracks
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
unconfirmed.append(track)
|
||||
else:
|
||||
tracked_stracks.append(track)
|
||||
# Step 2: First association, with high score detection boxes
|
||||
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
self.multi_predict(strack_pool)
|
||||
if hasattr(self, 'gmc') and img is not None:
|
||||
warp = self.gmc.apply(img, dets)
|
||||
STrack.multi_gmc(strack_pool, warp)
|
||||
STrack.multi_gmc(unconfirmed, warp)
|
||||
|
||||
dists = self.get_dists(strack_pool, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
|
||||
|
||||
for itracked, idet in matches:
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
# Step 3: Second association, with low score detection boxes
|
||||
# association the untrack to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
# TODO
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections_second)
|
||||
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections_second[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if track.state != TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
# Deal with unconfirmed tracks, usually tracks with only one beginning frame
|
||||
detections = [detections[i] for i in u_detection]
|
||||
dists = self.get_dists(unconfirmed, detections)
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_stracks.append(unconfirmed[itracked])
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
# Step 4: Init new stracks
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
if track.score < self.args.new_track_thresh:
|
||||
continue
|
||||
track.activate(self.kalman_filter, self.frame_id)
|
||||
activated_stracks.append(track)
|
||||
# Step 5: Update state
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
if len(self.removed_stracks) > 1000:
|
||||
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
|
||||
return np.asarray(
|
||||
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
|
||||
dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
"""Returns a Kalman filter object for tracking bounding boxes."""
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
"""Initialize object tracking with detections and scores using STrack algorithm."""
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
"""Calculates the distance between tracks and detections using IOU and fuses scores."""
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
"""Returns the predicted tracks using the YOLOv8 network."""
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
def reset_id(self):
|
||||
"""Resets the ID counter of STrack."""
|
||||
STrack.reset_id()
|
||||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
"""Combine two lists of stracks into a single one."""
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
exists[t.track_id] = 1
|
||||
res.append(t)
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if not exists.get(tid, 0):
|
||||
exists[tid] = 1
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def sub_stracks(tlista, tlistb):
|
||||
"""DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
|
||||
stracks = {t.track_id: t for t in tlista}
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
"""
|
||||
track_ids_b = {t.track_id for t in tlistb}
|
||||
return [t for t in tlista if t.track_id not in track_ids_b]
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
"""Remove duplicate stracks with non-maximum IOU distance."""
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
for p, q in zip(*pairs):
|
||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
||||
timeq = stracksb[q].frame_id - stracksb[q].start_frame
|
||||
if timep > timeq:
|
||||
dupb.append(q)
|
||||
else:
|
||||
dupa.append(p)
|
||||
resa = [t for i, t in enumerate(stracksa) if i not in dupa]
|
||||
resb = [t for i, t in enumerate(stracksb) if i not in dupb]
|
||||
return resa, resb
|
66
ultralytics/trackers/track.py
Normal file
66
ultralytics/trackers/track.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils import IterableSimpleNamespace, yaml_load
|
||||
from ultralytics.utils.checks import check_yaml
|
||||
|
||||
from .bot_sort import BOTSORT
|
||||
from .byte_tracker import BYTETracker
|
||||
|
||||
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
|
||||
|
||||
|
||||
def on_predict_start(predictor, persist=False):
|
||||
"""
|
||||
Initialize trackers for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
predictor (object): The predictor object to initialize trackers for.
|
||||
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
|
||||
"""
|
||||
if hasattr(predictor, 'trackers') and persist:
|
||||
return
|
||||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
|
||||
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
|
||||
trackers = []
|
||||
for _ in range(predictor.dataset.bs):
|
||||
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
|
||||
trackers.append(tracker)
|
||||
predictor.trackers = trackers
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor):
|
||||
"""Postprocess detected boxes and update with object tracking."""
|
||||
bs = predictor.dataset.bs
|
||||
im0s = predictor.batch[1]
|
||||
for i in range(bs):
|
||||
det = predictor.results[i].boxes.cpu().numpy()
|
||||
if len(det) == 0:
|
||||
continue
|
||||
tracks = predictor.trackers[i].update(det, im0s[i])
|
||||
if len(tracks) == 0:
|
||||
continue
|
||||
idx = tracks[:, -1].astype(int)
|
||||
predictor.results[i] = predictor.results[i][idx]
|
||||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
|
||||
|
||||
|
||||
def register_tracker(model, persist):
|
||||
"""
|
||||
Register tracking callbacks to the model for object tracking during prediction.
|
||||
|
||||
Args:
|
||||
model (object): The model object to register tracking callbacks for.
|
||||
persist (bool): Whether to persist the trackers if they already exist.
|
||||
|
||||
"""
|
||||
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
|
||||
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)
|
0
ultralytics/trackers/utils/__init__.py
Normal file
0
ultralytics/trackers/utils/__init__.py
Normal file
319
ultralytics/trackers/utils/gmc.py
Normal file
319
ultralytics/trackers/utils/gmc.py
Normal file
@ -0,0 +1,319 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import copy
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.utils import LOGGER
|
||||
|
||||
|
||||
class GMC:
|
||||
|
||||
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
|
||||
"""Initialize a video tracker with specified parameters."""
|
||||
super().__init__()
|
||||
|
||||
self.method = method
|
||||
self.downscale = max(1, int(downscale))
|
||||
|
||||
if self.method == 'orb':
|
||||
self.detector = cv2.FastFeatureDetector_create(20)
|
||||
self.extractor = cv2.ORB_create()
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
|
||||
|
||||
elif self.method == 'sift':
|
||||
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
|
||||
|
||||
elif self.method == 'ecc':
|
||||
number_of_iterations = 5000
|
||||
termination_eps = 1e-6
|
||||
self.warp_mode = cv2.MOTION_EUCLIDEAN
|
||||
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
|
||||
|
||||
elif self.method == 'sparseOptFlow':
|
||||
self.feature_params = dict(maxCorners=1000,
|
||||
qualityLevel=0.01,
|
||||
minDistance=1,
|
||||
blockSize=3,
|
||||
useHarrisDetector=False,
|
||||
k=0.04)
|
||||
# self.gmc_file = open('GMC_results.txt', 'w')
|
||||
|
||||
elif self.method in ['file', 'files']:
|
||||
seqName = verbose[0]
|
||||
ablation = verbose[1]
|
||||
if ablation:
|
||||
filePath = r'tracker/GMC_files/MOT17_ablation'
|
||||
else:
|
||||
filePath = r'tracker/GMC_files/MOTChallenge'
|
||||
|
||||
if '-FRCNN' in seqName:
|
||||
seqName = seqName[:-6]
|
||||
elif '-DPM' in seqName or '-SDP' in seqName:
|
||||
seqName = seqName[:-4]
|
||||
self.gmcFile = open(f'{filePath}/GMC-{seqName}.txt')
|
||||
|
||||
if self.gmcFile is None:
|
||||
raise ValueError(f'Error: Unable to open GMC file in directory:{filePath}')
|
||||
elif self.method in ['none', 'None']:
|
||||
self.method = 'none'
|
||||
else:
|
||||
raise ValueError(f'Error: Unknown CMC method:{method}')
|
||||
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
|
||||
self.initializedFirstFrame = False
|
||||
|
||||
def apply(self, raw_frame, detections=None):
|
||||
"""Apply object detection on a raw frame using specified method."""
|
||||
if self.method in ['orb', 'sift']:
|
||||
return self.applyFeatures(raw_frame, detections)
|
||||
elif self.method == 'ecc':
|
||||
return self.applyEcc(raw_frame, detections)
|
||||
elif self.method == 'sparseOptFlow':
|
||||
return self.applySparseOptFlow(raw_frame, detections)
|
||||
elif self.method == 'file':
|
||||
return self.applyFile(raw_frame, detections)
|
||||
elif self.method == 'none':
|
||||
return np.eye(2, 3)
|
||||
else:
|
||||
return np.eye(2, 3)
|
||||
|
||||
def applyEcc(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3, dtype=np.float32)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
||||
try:
|
||||
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING: find transform failed. Set warp as identity {e}')
|
||||
|
||||
return H
|
||||
|
||||
def applyFeatures(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Find the keypoints
|
||||
mask = np.zeros_like(frame)
|
||||
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
|
||||
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
|
||||
if detections is not None:
|
||||
for det in detections:
|
||||
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
||||
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
|
||||
|
||||
keypoints = self.detector.detect(frame, mask)
|
||||
|
||||
# Compute the descriptors
|
||||
keypoints, descriptors = self.extractor.compute(frame, keypoints)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Match descriptors.
|
||||
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
|
||||
|
||||
# Filtered matches based on smallest spatial distance
|
||||
matches = []
|
||||
spatialDistances = []
|
||||
|
||||
maxSpatialDistance = 0.25 * np.array([width, height])
|
||||
|
||||
# Handle empty matches case
|
||||
if len(knnMatches) == 0:
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
for m, n in knnMatches:
|
||||
if m.distance < 0.9 * n.distance:
|
||||
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
||||
currKeyPointLocation = keypoints[m.trainIdx].pt
|
||||
|
||||
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
|
||||
prevKeyPointLocation[1] - currKeyPointLocation[1])
|
||||
|
||||
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
|
||||
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
|
||||
spatialDistances.append(spatialDistance)
|
||||
matches.append(m)
|
||||
|
||||
meanSpatialDistances = np.mean(spatialDistances, 0)
|
||||
stdSpatialDistances = np.std(spatialDistances, 0)
|
||||
|
||||
inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
|
||||
|
||||
goodMatches = []
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
for i in range(len(matches)):
|
||||
if inliers[i, 0] and inliers[i, 1]:
|
||||
goodMatches.append(matches[i])
|
||||
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
|
||||
currPoints.append(keypoints[matches[i].trainIdx].pt)
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Draw the keypoint matches on the output image
|
||||
# if False:
|
||||
# import matplotlib.pyplot as plt
|
||||
# matches_img = np.hstack((self.prevFrame, frame))
|
||||
# matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
|
||||
# W = np.size(self.prevFrame, 1)
|
||||
# for m in goodMatches:
|
||||
# prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
|
||||
# curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
|
||||
# curr_pt[0] += W
|
||||
# color = np.random.randint(0, 255, 3)
|
||||
# color = (int(color[0]), int(color[1]), int(color[2]))
|
||||
#
|
||||
# matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
|
||||
# matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
|
||||
# matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
|
||||
#
|
||||
# plt.figure()
|
||||
# plt.imshow(matches_img)
|
||||
# plt.show()
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning('WARNING: not enough matching points')
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
def applySparseOptFlow(self, raw_frame, detections=None):
|
||||
"""Initialize."""
|
||||
# t0 = time.time()
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
|
||||
# Find the keypoints
|
||||
keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Find correspondences
|
||||
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
||||
|
||||
# Leave good correspondences only
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
|
||||
for i in range(len(status)):
|
||||
if status[i]:
|
||||
prevPoints.append(self.prevKeyPoints[i])
|
||||
currPoints.append(matchedKeypoints[i])
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
LOGGER.warning('WARNING: not enough matching points')
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
# gmc_line = str(1000 * (time.time() - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
|
||||
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
|
||||
# self.gmc_file.write(gmc_line)
|
||||
|
||||
return H
|
||||
|
||||
def applyFile(self, raw_frame, detections=None):
|
||||
"""Return the homography matrix based on the GCPs in the next line of the input GMC file."""
|
||||
line = self.gmcFile.readline()
|
||||
tokens = line.split('\t')
|
||||
H = np.eye(2, 3, dtype=np.float_)
|
||||
H[0, 0] = float(tokens[1])
|
||||
H[0, 1] = float(tokens[2])
|
||||
H[0, 2] = float(tokens[3])
|
||||
H[1, 0] = float(tokens[4])
|
||||
H[1, 1] = float(tokens[5])
|
||||
H[1, 2] = float(tokens[6])
|
||||
|
||||
return H
|
462
ultralytics/trackers/utils/kalman_filter.py
Normal file
462
ultralytics/trackers/utils/kalman_filter.py
Normal file
@ -0,0 +1,462 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
# Table for the 0.95 quantile of the chi-square distribution with N degrees of freedom (contains values for N=1, ..., 9)
|
||||
# Taken from MATLAB/Octave's chi2inv function and used as Mahalanobis gating threshold.
|
||||
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}
|
||||
|
||||
|
||||
class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, a, h, vx, vy, va, vh
|
||||
|
||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
|
||||
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
# mean = np.dot(self._motion_mat, mean)
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
||||
previous time step.
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
|
||||
np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
||||
|
||||
|
||||
class KalmanFilterXYWH:
|
||||
"""
|
||||
For BoT-SORT
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, w, h, vx, vy, vw, vh
|
||||
|
||||
contains the bounding box center position (x, y), width w, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, w, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Kalman filter model matrices with motion and observation uncertainties."""
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, w, h) with center position (x, y),
|
||||
width w, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrix of the object states at the
|
||||
previous time step.
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
|
||||
is the center position, w the width, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
|
||||
np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
229
ultralytics/trackers/utils/matching.py
Normal file
229
ultralytics/trackers/utils/matching.py
Normal file
@ -0,0 +1,229 @@
|
||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from .kalman_filter import chi2inv95
|
||||
|
||||
try:
|
||||
import lap # for linear_assignment
|
||||
|
||||
assert lap.__version__ # verify package is not directory
|
||||
except (ImportError, AssertionError, AttributeError):
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
check_requirements('lapx>=0.5.2') # update to lap package from https://github.com/rathaROG/lapx
|
||||
import lap
|
||||
|
||||
|
||||
def merge_matches(m1, m2, shape):
|
||||
"""Merge two sets of matches and return matched and unmatched indices."""
|
||||
O, P, Q = shape
|
||||
m1 = np.asarray(m1)
|
||||
m2 = np.asarray(m2)
|
||||
|
||||
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
|
||||
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
|
||||
|
||||
mask = M1 * M2
|
||||
match = mask.nonzero()
|
||||
match = list(zip(match[0], match[1]))
|
||||
unmatched_O = tuple(set(range(O)) - {i for i, j in match})
|
||||
unmatched_Q = tuple(set(range(Q)) - {j for i, j in match})
|
||||
|
||||
return match, unmatched_O, unmatched_Q
|
||||
|
||||
|
||||
def _indices_to_matches(cost_matrix, indices, thresh):
|
||||
"""_indices_to_matches: Return matched and unmatched indices given a cost matrix, indices, and a threshold."""
|
||||
matched_cost = cost_matrix[tuple(zip(*indices))]
|
||||
matched_mask = (matched_cost <= thresh)
|
||||
|
||||
matches = indices[matched_mask]
|
||||
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
|
||||
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh, use_lap=True):
|
||||
"""Linear assignment implementations with scipy and lap.lapjv."""
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
|
||||
if use_lap:
|
||||
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
else:
|
||||
# Scipy linear sum assignment is NOT working correctly, DO NOT USE
|
||||
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
|
||||
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
|
||||
unmatched = np.ones(cost_matrix.shape)
|
||||
for i, xi in matches:
|
||||
unmatched[i, xi] = 0.0
|
||||
unmatched_a = np.where(unmatched.all(1))[0]
|
||||
unmatched_b = np.where(unmatched.all(0))[0]
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def ious(atlbrs, btlbrs):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
|
||||
:rtype ious np.ndarray
|
||||
"""
|
||||
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
|
||||
if ious.size == 0:
|
||||
return ious
|
||||
|
||||
ious = bbox_ious(np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32))
|
||||
return ious
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atracks: list[STrack]
|
||||
:type btracks: list[STrack]
|
||||
|
||||
:rtype cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
|
||||
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlbr for track in atracks]
|
||||
btlbrs = [track.tlbr for track in btracks]
|
||||
_ious = ious(atlbrs, btlbrs)
|
||||
return 1 - _ious # cost matrix
|
||||
|
||||
|
||||
def v_iou_distance(atracks, btracks):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atracks: list[STrack]
|
||||
:type btracks: list[STrack]
|
||||
|
||||
:rtype cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
|
||||
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
|
||||
btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
|
||||
_ious = ious(atlbrs, btlbrs)
|
||||
return 1 - _ious # cost matrix
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='cosine'):
|
||||
"""
|
||||
:param tracks: list[STrack]
|
||||
:param detections: list[BaseTrack]
|
||||
:param metric:
|
||||
:return: cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
|
||||
# for i, track in enumerate(tracks):
|
||||
# cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
|
||||
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
|
||||
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
|
||||
"""Apply gating to the cost matrix based on predicted tracks and detected objects."""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
|
||||
"""Fuse motion between tracks and detections with gating and Kalman filtering."""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_iou(cost_matrix, tracks, detections):
|
||||
"""Fuses ReID and IoU similarity matrices to yield a cost matrix for object tracking."""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
reid_sim = 1 - cost_matrix
|
||||
iou_dist = iou_distance(tracks, detections)
|
||||
iou_sim = 1 - iou_dist
|
||||
fuse_sim = reid_sim * (1 + iou_sim) / 2
|
||||
# det_scores = np.array([det.score for det in detections])
|
||||
# det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
|
||||
return 1 - fuse_sim # fuse cost
|
||||
|
||||
|
||||
def fuse_score(cost_matrix, detections):
|
||||
"""Fuses cost matrix with detection scores to produce a single similarity matrix."""
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
iou_sim = 1 - cost_matrix
|
||||
det_scores = np.array([det.score for det in detections])
|
||||
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
|
||||
fuse_sim = iou_sim * det_scores
|
||||
return 1 - fuse_sim # fuse_cost
|
||||
|
||||
|
||||
def bbox_ious(box1, box2, eps=1e-7):
|
||||
"""
|
||||
Calculate the Intersection over Union (IoU) between pairs of bounding boxes.
|
||||
|
||||
Args:
|
||||
box1 (np.array): A numpy array of shape (n, 4) representing 'n' bounding boxes.
|
||||
Each row is in the format (x1, y1, x2, y2).
|
||||
box2 (np.array): A numpy array of shape (m, 4) representing 'm' bounding boxes.
|
||||
Each row is in the format (x1, y1, x2, y2).
|
||||
eps (float, optional): A small constant to prevent division by zero. Defaults to 1e-7.
|
||||
|
||||
Returns:
|
||||
(np.array): A numpy array of shape (n, m) representing the IoU scores for each pair
|
||||
of bounding boxes from box1 and box2.
|
||||
|
||||
Note:
|
||||
The bounding box coordinates are expected to be in the format (x1, y1, x2, y2).
|
||||
"""
|
||||
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
|
||||
return inter_area / (box2_area + box1_area[:, None] - inter_area + eps)
|
Reference in New Issue
Block a user