Integrate ByteTracker and BoT-SORT trackers (#788)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>single_channel
parent
d99e04daa1
commit
ed6c54da7a
@ -0,0 +1,32 @@
|
||||
## Tracker
|
||||
|
||||
### Trackers
|
||||
|
||||
- [x] ByteTracker
|
||||
- [x] BoT-SORT
|
||||
|
||||
### Usage
|
||||
|
||||
python interface:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("yolov8n.pt") # or a segmentation model .i.e yolov8n-seg.pt
|
||||
model.track(
|
||||
source="video/streams",
|
||||
stream=True,
|
||||
tracker="botsort.yaml/bytetrack.yaml",
|
||||
...,
|
||||
)
|
||||
```
|
||||
|
||||
cli:
|
||||
|
||||
```bash
|
||||
yolo detect track source=... tracker=...
|
||||
yolo segment track source=... tracker=...
|
||||
```
|
||||
|
||||
By default, trackers will use the configuration in `ultralytics/tracker/cfg`.
|
||||
We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/tracker/cfg`.
|
@ -0,0 +1 @@
|
||||
from .trackers import BYTETracker, BOTSORT
|
@ -0,0 +1,15 @@
|
||||
tracker_type: botsort # tracker type, ['botsort', 'bytetrack']
|
||||
track_high_thresh: 0.5 # threshold for the first association
|
||||
track_low_thresh: 0.1 # threshold for the second association
|
||||
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
|
||||
track_buffer: 30 # buffer to calculate the time when to remove tracks
|
||||
match_thresh: 0.8 # threshold for matching tracks
|
||||
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
|
||||
# mot20: False # for tracker evaluation(not used for now)
|
||||
|
||||
# Botsort settings
|
||||
cmc_method: sparseOptFlow # method of global motion compensation
|
||||
# ReID model related thresh (not supported yet)
|
||||
proximity_thresh: 0.5
|
||||
appearance_thresh: 0.25
|
||||
with_reid: False
|
@ -0,0 +1,8 @@
|
||||
tracker_type: bytetrack # tracker type, ['botsort', 'bytetrack']
|
||||
track_high_thresh: 0.5 # threshold for the first association
|
||||
track_low_thresh: 0.1 # threshold for the second association
|
||||
new_track_thresh: 0.6 # threshold for init new track if the detection does not match any tracks
|
||||
track_buffer: 30 # buffer to calculate the time when to remove tracks
|
||||
match_thresh: 0.8 # threshold for matching tracks
|
||||
# min_box_area: 10 # threshold for min box areas(for tracker evaluation, not used for now)
|
||||
# mot20: False # for tracker evaluation(not used for now)
|
@ -0,0 +1,41 @@
|
||||
from ultralytics.tracker import BYTETracker, BOTSORT
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
|
||||
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
|
||||
import torch
|
||||
|
||||
TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
|
||||
check_requirements('lap') # for linear_assignment
|
||||
|
||||
|
||||
def on_predict_start(predictor):
|
||||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
assert cfg.tracker_type in ["bytetrack", "botsort"], \
|
||||
f"Only support 'bytetrack' and 'botsort' for now, but got '{cfg.tracker_type}'"
|
||||
trackers = []
|
||||
for _ in range(predictor.dataset.bs):
|
||||
tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
|
||||
trackers.append(tracker)
|
||||
predictor.trackers = trackers
|
||||
|
||||
|
||||
def on_predict_postprocess_end(predictor):
|
||||
bs = predictor.dataset.bs
|
||||
im0s = predictor.batch[2]
|
||||
im0s = im0s if isinstance(im0s, list) else [im0s]
|
||||
for i in range(bs):
|
||||
det = predictor.results[i].boxes.cpu().numpy()
|
||||
if len(det) == 0:
|
||||
continue
|
||||
tracks = predictor.trackers[i].update(det, im0s[i])
|
||||
if len(tracks) == 0:
|
||||
continue
|
||||
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
|
||||
if predictor.results[i].masks is not None:
|
||||
idx = tracks[:, -1].tolist()
|
||||
predictor.results[i].masks = predictor.results[i].masks[idx]
|
||||
|
||||
|
||||
def register_tracker(model):
|
||||
model.add_callback("on_predict_start", on_predict_start)
|
||||
model.add_callback("on_predict_postprocess_end", on_predict_postprocess_end)
|
@ -0,0 +1,2 @@
|
||||
from .byte_tracker import BYTETracker
|
||||
from .bot_sort import BOTSORT
|
@ -0,0 +1,52 @@
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class TrackState:
|
||||
New = 0
|
||||
Tracked = 1
|
||||
Lost = 2
|
||||
Removed = 3
|
||||
|
||||
|
||||
class BaseTrack:
|
||||
_count = 0
|
||||
|
||||
track_id = 0
|
||||
is_activated = False
|
||||
state = TrackState.New
|
||||
|
||||
history = OrderedDict()
|
||||
features = []
|
||||
curr_feature = None
|
||||
score = 0
|
||||
start_frame = 0
|
||||
frame_id = 0
|
||||
time_since_update = 0
|
||||
|
||||
# multi-camera
|
||||
location = (np.inf, np.inf)
|
||||
|
||||
@property
|
||||
def end_frame(self):
|
||||
return self.frame_id
|
||||
|
||||
@staticmethod
|
||||
def next_id():
|
||||
BaseTrack._count += 1
|
||||
return BaseTrack._count
|
||||
|
||||
def activate(self, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
def predict(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def mark_lost(self):
|
||||
self.state = TrackState.Lost
|
||||
|
||||
def mark_removed(self):
|
||||
self.state = TrackState.Removed
|
@ -0,0 +1,132 @@
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
from ..utils import matching
|
||||
from ..utils.gmc import GMC
|
||||
from ..utils.kalman_filter import KalmanFilterXYWH
|
||||
from .byte_tracker import STrack, BYTETracker
|
||||
from .basetrack import TrackState
|
||||
|
||||
|
||||
class BOTrack(STrack):
|
||||
shared_kalman = KalmanFilterXYWH()
|
||||
|
||||
def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
|
||||
super().__init__(tlwh, score, cls)
|
||||
|
||||
self.smooth_feat = None
|
||||
self.curr_feat = None
|
||||
if feat is not None:
|
||||
self.update_features(feat)
|
||||
self.features = deque([], maxlen=feat_history)
|
||||
self.alpha = 0.9
|
||||
|
||||
def update_features(self, feat):
|
||||
feat /= np.linalg.norm(feat)
|
||||
self.curr_feat = feat
|
||||
if self.smooth_feat is None:
|
||||
self.smooth_feat = feat
|
||||
else:
|
||||
self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
|
||||
self.features.append(feat)
|
||||
self.smooth_feat /= np.linalg.norm(self.smooth_feat)
|
||||
|
||||
def predict(self):
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[6] = 0
|
||||
mean_state[7] = 0
|
||||
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().re_activate(new_track, frame_id, new_id)
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
if new_track.curr_feat is not None:
|
||||
self.update_features(new_track.curr_feat)
|
||||
super().update(new_track, frame_id)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][6] = 0
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
return self.tlwh_to_xywh(tlwh)
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xywh(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, width,
|
||||
height)`.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
return ret
|
||||
|
||||
|
||||
class BOTSORT(BYTETracker):
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
super().__init__(args, frame_rate)
|
||||
# ReID module
|
||||
self.proximity_thresh = args.proximity_thresh
|
||||
self.appearance_thresh = args.appearance_thresh
|
||||
|
||||
if args.with_reid:
|
||||
# haven't supported bot-sort(reid) yet
|
||||
self.encoder = None
|
||||
# self.gmc = GMC(method=args.cmc_method, verbose=[args.name, args.ablation])
|
||||
self.gmc = GMC(method=args.cmc_method)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
return KalmanFilterXYWH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
if len(dets) == 0:
|
||||
return []
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
features_keep = self.encoder.inference(img, dets)
|
||||
detections = [BOTrack(xyxy, s, c, f) for (xyxy, s, c, f) in zip(dets, scores, cls, features_keep)]
|
||||
else:
|
||||
detections = [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)]
|
||||
return detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
dists_mask = (dists > self.proximity_thresh)
|
||||
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
|
||||
if self.args.with_reid and self.encoder is not None:
|
||||
emb_dists = matching.embedding_distance(tracks, detections) / 2.0
|
||||
emb_dists[emb_dists > self.appearance_thresh] = 1.0
|
||||
emb_dists[dists_mask] = 1.0
|
||||
dists = np.minimum(dists, emb_dists)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
BOTrack.multi_predict(tracks)
|
@ -0,0 +1,338 @@
|
||||
import numpy as np
|
||||
|
||||
from .basetrack import BaseTrack, TrackState
|
||||
from ..utils import matching
|
||||
from ..utils.kalman_filter import KalmanFilterXYAH
|
||||
|
||||
|
||||
class STrack(BaseTrack):
|
||||
shared_kalman = KalmanFilterXYAH()
|
||||
|
||||
def __init__(self, tlwh, score, cls):
|
||||
|
||||
# wait activate
|
||||
self._tlwh = np.asarray(self.tlbr_to_tlwh(tlwh[:-1]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
self.is_activated = False
|
||||
|
||||
self.score = score
|
||||
self.tracklet_len = 0
|
||||
self.cls = cls
|
||||
self.idx = tlwh[-1]
|
||||
|
||||
def predict(self):
|
||||
mean_state = self.mean.copy()
|
||||
if self.state != TrackState.Tracked:
|
||||
mean_state[7] = 0
|
||||
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
|
||||
|
||||
@staticmethod
|
||||
def multi_predict(stracks):
|
||||
if len(stracks) <= 0:
|
||||
return
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
for i, st in enumerate(stracks):
|
||||
if st.state != TrackState.Tracked:
|
||||
multi_mean[i][7] = 0
|
||||
multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
@staticmethod
|
||||
def multi_gmc(stracks, H=np.eye(2, 3)):
|
||||
if len(stracks) > 0:
|
||||
multi_mean = np.asarray([st.mean.copy() for st in stracks])
|
||||
multi_covariance = np.asarray([st.covariance for st in stracks])
|
||||
|
||||
R = H[:2, :2]
|
||||
R8x8 = np.kron(np.eye(4, dtype=float), R)
|
||||
t = H[:2, 2]
|
||||
|
||||
for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
|
||||
mean = R8x8.dot(mean)
|
||||
mean[:2] += t
|
||||
cov = R8x8.dot(cov).dot(R8x8.transpose())
|
||||
|
||||
stracks[i].mean = mean
|
||||
stracks[i].covariance = cov
|
||||
|
||||
def activate(self, kalman_filter, frame_id):
|
||||
"""Start a new tracklet"""
|
||||
self.kalman_filter = kalman_filter
|
||||
self.track_id = self.next_id()
|
||||
self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
|
||||
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
if frame_id == 1:
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
self.start_frame = frame_id
|
||||
|
||||
def re_activate(self, new_track, frame_id, new_id=False):
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_track.tlwh))
|
||||
self.tracklet_len = 0
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
self.frame_id = frame_id
|
||||
if new_id:
|
||||
self.track_id = self.next_id()
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.idx = new_track.idx
|
||||
|
||||
def update(self, new_track, frame_id):
|
||||
"""
|
||||
Update a matched track
|
||||
:type new_track: STrack
|
||||
:type frame_id: int
|
||||
:type update_feature: bool
|
||||
:return:
|
||||
"""
|
||||
self.frame_id = frame_id
|
||||
self.tracklet_len += 1
|
||||
|
||||
new_tlwh = new_track.tlwh
|
||||
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
|
||||
self.convert_coords(new_tlwh))
|
||||
self.state = TrackState.Tracked
|
||||
self.is_activated = True
|
||||
|
||||
self.score = new_track.score
|
||||
self.cls = new_track.cls
|
||||
self.idx = new_track.idx
|
||||
|
||||
def convert_coords(self, tlwh):
|
||||
return self.tlwh_to_xyah(tlwh)
|
||||
|
||||
@property
|
||||
def tlwh(self):
|
||||
"""Get current position in bounding box format `(top left x, top left y,
|
||||
width, height)`.
|
||||
"""
|
||||
if self.mean is None:
|
||||
return self._tlwh.copy()
|
||||
ret = self.mean[:4].copy()
|
||||
ret[2] *= ret[3]
|
||||
ret[:2] -= ret[2:] / 2
|
||||
return ret
|
||||
|
||||
@property
|
||||
def tlbr(self):
|
||||
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
|
||||
`(top left, bottom right)`.
|
||||
"""
|
||||
ret = self.tlwh.copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_xyah(tlwh):
|
||||
"""Convert bounding box to format `(center x, center y, aspect ratio,
|
||||
height)`, where the aspect ratio is `width / height`.
|
||||
"""
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[:2] += ret[2:] / 2
|
||||
ret[2] /= ret[3]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlbr_to_tlwh(tlbr):
|
||||
ret = np.asarray(tlbr).copy()
|
||||
ret[2:] -= ret[:2]
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def tlwh_to_tlbr(tlwh):
|
||||
ret = np.asarray(tlwh).copy()
|
||||
ret[2:] += ret[:2]
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
|
||||
|
||||
|
||||
class BYTETracker:
|
||||
|
||||
def __init__(self, args, frame_rate=30):
|
||||
self.tracked_stracks = [] # type: list[STrack]
|
||||
self.lost_stracks = [] # type: list[STrack]
|
||||
self.removed_stracks = [] # type: list[STrack]
|
||||
|
||||
self.frame_id = 0
|
||||
self.args = args
|
||||
self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)
|
||||
self.kalman_filter = self.get_kalmanfilter()
|
||||
|
||||
def update(self, results, img=None):
|
||||
self.frame_id += 1
|
||||
activated_starcks = []
|
||||
refind_stracks = []
|
||||
lost_stracks = []
|
||||
removed_stracks = []
|
||||
|
||||
scores = results.conf
|
||||
bboxes = results.xyxy
|
||||
# add index
|
||||
bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
|
||||
cls = results.cls
|
||||
|
||||
remain_inds = scores > self.args.track_high_thresh
|
||||
inds_low = scores > self.args.track_low_thresh
|
||||
inds_high = scores < self.args.track_high_thresh
|
||||
|
||||
inds_second = np.logical_and(inds_low, inds_high)
|
||||
dets_second = bboxes[inds_second]
|
||||
dets = bboxes[remain_inds]
|
||||
scores_keep = scores[remain_inds]
|
||||
scores_second = scores[inds_second]
|
||||
cls_keep = cls[remain_inds]
|
||||
cls_second = cls[inds_second]
|
||||
|
||||
detections = self.init_track(dets, scores_keep, cls_keep, img)
|
||||
""" Add newly detected tracklets to tracked_stracks"""
|
||||
unconfirmed = []
|
||||
tracked_stracks = [] # type: list[STrack]
|
||||
for track in self.tracked_stracks:
|
||||
if not track.is_activated:
|
||||
unconfirmed.append(track)
|
||||
else:
|
||||
tracked_stracks.append(track)
|
||||
""" Step 2: First association, with high score detection boxes"""
|
||||
strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
|
||||
# Predict the current location with KF
|
||||
self.multi_predict(strack_pool)
|
||||
if hasattr(self, "gmc"):
|
||||
warp = self.gmc.apply(img, dets)
|
||||
STrack.multi_gmc(strack_pool, warp)
|
||||
STrack.multi_gmc(unconfirmed, warp)
|
||||
|
||||
dists = self.get_dists(strack_pool, detections)
|
||||
matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
|
||||
|
||||
for itracked, idet in matches:
|
||||
track = strack_pool[itracked]
|
||||
det = detections[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
""" Step 3: Second association, with low score detection boxes"""
|
||||
# association the untrack to the low score detections
|
||||
detections_second = self.init_track(dets_second, scores_second, cls_second, img)
|
||||
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
|
||||
# TODO
|
||||
dists = matching.iou_distance(r_tracked_stracks, detections_second)
|
||||
matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
|
||||
for itracked, idet in matches:
|
||||
track = r_tracked_stracks[itracked]
|
||||
det = detections_second[idet]
|
||||
if track.state == TrackState.Tracked:
|
||||
track.update(det, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
else:
|
||||
track.re_activate(det, self.frame_id, new_id=False)
|
||||
refind_stracks.append(track)
|
||||
|
||||
for it in u_track:
|
||||
track = r_tracked_stracks[it]
|
||||
if track.state != TrackState.Lost:
|
||||
track.mark_lost()
|
||||
lost_stracks.append(track)
|
||||
"""Deal with unconfirmed tracks, usually tracks with only one beginning frame"""
|
||||
detections = [detections[i] for i in u_detection]
|
||||
dists = self.get_dists(unconfirmed, detections)
|
||||
matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
|
||||
for itracked, idet in matches:
|
||||
unconfirmed[itracked].update(detections[idet], self.frame_id)
|
||||
activated_starcks.append(unconfirmed[itracked])
|
||||
for it in u_unconfirmed:
|
||||
track = unconfirmed[it]
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
""" Step 4: Init new stracks"""
|
||||
for inew in u_detection:
|
||||
track = detections[inew]
|
||||
if track.score < self.args.new_track_thresh:
|
||||
continue
|
||||
track.activate(self.kalman_filter, self.frame_id)
|
||||
activated_starcks.append(track)
|
||||
""" Step 5: Update state"""
|
||||
for track in self.lost_stracks:
|
||||
if self.frame_id - track.end_frame > self.max_time_lost:
|
||||
track.mark_removed()
|
||||
removed_stracks.append(track)
|
||||
|
||||
self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_starcks)
|
||||
self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
|
||||
self.lost_stracks.extend(lost_stracks)
|
||||
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
|
||||
self.removed_stracks.extend(removed_stracks)
|
||||
self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
|
||||
output = [
|
||||
track.tlbr.tolist() + [track.track_id, track.score, track.cls, track.idx] for track in self.tracked_stracks
|
||||
if track.is_activated]
|
||||
return np.asarray(output, dtype=np.float32)
|
||||
|
||||
def get_kalmanfilter(self):
|
||||
return KalmanFilterXYAH()
|
||||
|
||||
def init_track(self, dets, scores, cls, img=None):
|
||||
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
|
||||
|
||||
def get_dists(self, tracks, detections):
|
||||
dists = matching.iou_distance(tracks, detections)
|
||||
# TODO: mot20
|
||||
# if not self.args.mot20:
|
||||
dists = matching.fuse_score(dists, detections)
|
||||
return dists
|
||||
|
||||
def multi_predict(self, tracks):
|
||||
STrack.multi_predict(tracks)
|
||||
|
||||
@staticmethod
|
||||
def joint_stracks(tlista, tlistb):
|
||||
exists = {}
|
||||
res = []
|
||||
for t in tlista:
|
||||
exists[t.track_id] = 1
|
||||
res.append(t)
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if not exists.get(tid, 0):
|
||||
exists[tid] = 1
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def sub_stracks(tlista, tlistb):
|
||||
stracks = {t.track_id: t for t in tlista}
|
||||
for t in tlistb:
|
||||
tid = t.track_id
|
||||
if stracks.get(tid, 0):
|
||||
del stracks[tid]
|
||||
return list(stracks.values())
|
||||
|
||||
@staticmethod
|
||||
def remove_duplicate_stracks(stracksa, stracksb):
|
||||
pdist = matching.iou_distance(stracksa, stracksb)
|
||||
pairs = np.where(pdist < 0.15)
|
||||
dupa, dupb = [], []
|
||||
for p, q in zip(*pairs):
|
||||
timep = stracksa[p].frame_id - stracksa[p].start_frame
|
||||
timeq = stracksb[q].frame_id - stracksb[q].start_frame
|
||||
if timep > timeq:
|
||||
dupb.append(q)
|
||||
else:
|
||||
dupa.append(p)
|
||||
resa = [t for i, t in enumerate(stracksa) if i not in dupa]
|
||||
resb = [t for i, t in enumerate(stracksb) if i not in dupb]
|
||||
return resa, resb
|
@ -0,0 +1,314 @@
|
||||
import copy
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GMC:
|
||||
|
||||
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
|
||||
super().__init__()
|
||||
|
||||
self.method = method
|
||||
self.downscale = max(1, int(downscale))
|
||||
|
||||
if self.method == 'orb':
|
||||
self.detector = cv2.FastFeatureDetector_create(20)
|
||||
self.extractor = cv2.ORB_create()
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
|
||||
|
||||
elif self.method == 'sift':
|
||||
self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
|
||||
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
|
||||
|
||||
elif self.method == 'ecc':
|
||||
number_of_iterations = 5000
|
||||
termination_eps = 1e-6
|
||||
self.warp_mode = cv2.MOTION_EUCLIDEAN
|
||||
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
|
||||
|
||||
elif self.method == 'sparseOptFlow':
|
||||
self.feature_params = dict(maxCorners=1000,
|
||||
qualityLevel=0.01,
|
||||
minDistance=1,
|
||||
blockSize=3,
|
||||
useHarrisDetector=False,
|
||||
k=0.04)
|
||||
# self.gmc_file = open('GMC_results.txt', 'w')
|
||||
|
||||
elif self.method in ['file', 'files']:
|
||||
seqName = verbose[0]
|
||||
ablation = verbose[1]
|
||||
if ablation:
|
||||
filePath = r'tracker/GMC_files/MOT17_ablation'
|
||||
else:
|
||||
filePath = r'tracker/GMC_files/MOTChallenge'
|
||||
|
||||
if '-FRCNN' in seqName:
|
||||
seqName = seqName[:-6]
|
||||
elif '-DPM' in seqName or '-SDP' in seqName:
|
||||
seqName = seqName[:-4]
|
||||
self.gmcFile = open(f"{filePath}/GMC-{seqName}.txt")
|
||||
|
||||
if self.gmcFile is None:
|
||||
raise ValueError(f"Error: Unable to open GMC file in directory:{filePath}")
|
||||
elif self.method in ['none', 'None']:
|
||||
self.method = 'none'
|
||||
else:
|
||||
raise ValueError(f"Error: Unknown CMC method:{method}")
|
||||
|
||||
self.prevFrame = None
|
||||
self.prevKeyPoints = None
|
||||
self.prevDescriptors = None
|
||||
|
||||
self.initializedFirstFrame = False
|
||||
|
||||
def apply(self, raw_frame, detections=None):
|
||||
if self.method in ['orb', 'sift']:
|
||||
return self.applyFeaures(raw_frame, detections)
|
||||
elif self.method == 'ecc':
|
||||
return self.applyEcc(raw_frame, detections)
|
||||
elif self.method == 'sparseOptFlow':
|
||||
return self.applySparseOptFlow(raw_frame, detections)
|
||||
elif self.method == 'file':
|
||||
return self.applyFile(raw_frame, detections)
|
||||
elif self.method == 'none':
|
||||
return np.eye(2, 3)
|
||||
else:
|
||||
return np.eye(2, 3)
|
||||
|
||||
def applyEcc(self, raw_frame, detections=None):
|
||||
|
||||
# Initialize
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3, dtype=np.float32)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
if self.downscale > 1.0:
|
||||
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Run the ECC algorithm. The results are stored in warp_matrix.
|
||||
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
|
||||
try:
|
||||
(cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
|
||||
except Exception as e:
|
||||
print(f'Warning: find transform failed. Set warp as identity {e}')
|
||||
|
||||
return H
|
||||
|
||||
def applyFeaures(self, raw_frame, detections=None):
|
||||
|
||||
# Initialize
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image (TODO: consider using pyramids)
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
width = width // self.downscale
|
||||
height = height // self.downscale
|
||||
|
||||
# find the keypoints
|
||||
mask = np.zeros_like(frame)
|
||||
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
|
||||
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(0.98 * width)] = 255
|
||||
if detections is not None:
|
||||
for det in detections:
|
||||
tlbr = (det[:4] / self.downscale).astype(np.int_)
|
||||
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
|
||||
|
||||
keypoints = self.detector.detect(frame, mask)
|
||||
|
||||
# compute the descriptors
|
||||
keypoints, descriptors = self.extractor.compute(frame, keypoints)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# Match descriptors.
|
||||
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
|
||||
|
||||
# Filtered matches based on smallest spatial distance
|
||||
matches = []
|
||||
spatialDistances = []
|
||||
|
||||
maxSpatialDistance = 0.25 * np.array([width, height])
|
||||
|
||||
# Handle empty matches case
|
||||
if len(knnMatches) == 0:
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
for m, n in knnMatches:
|
||||
if m.distance < 0.9 * n.distance:
|
||||
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
|
||||
currKeyPointLocation = keypoints[m.trainIdx].pt
|
||||
|
||||
spatialDistance = (prevKeyPointLocation[0] - currKeyPointLocation[0],
|
||||
prevKeyPointLocation[1] - currKeyPointLocation[1])
|
||||
|
||||
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
|
||||
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
|
||||
spatialDistances.append(spatialDistance)
|
||||
matches.append(m)
|
||||
|
||||
meanSpatialDistances = np.mean(spatialDistances, 0)
|
||||
stdSpatialDistances = np.std(spatialDistances, 0)
|
||||
|
||||
inliesrs = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
|
||||
|
||||
goodMatches = []
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
for i in range(len(matches)):
|
||||
if inliesrs[i, 0] and inliesrs[i, 1]:
|
||||
goodMatches.append(matches[i])
|
||||
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
|
||||
currPoints.append(keypoints[matches[i].trainIdx].pt)
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Draw the keypoint matches on the output image
|
||||
if 0:
|
||||
matches_img = np.hstack((self.prevFrame, frame))
|
||||
matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
|
||||
W = np.size(self.prevFrame, 1)
|
||||
for m in goodMatches:
|
||||
prev_pt = np.array(self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
|
||||
curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
|
||||
curr_pt[0] += W
|
||||
color = np.random.randint(0, 255, (3,))
|
||||
color = (int(color[0]), int(color[1]), int(color[2]))
|
||||
|
||||
matches_img = cv2.line(matches_img, prev_pt, curr_pt, tuple(color), 1, cv2.LINE_AA)
|
||||
matches_img = cv2.circle(matches_img, prev_pt, 2, tuple(color), -1)
|
||||
matches_img = cv2.circle(matches_img, curr_pt, 2, tuple(color), -1)
|
||||
|
||||
plt.figure()
|
||||
plt.imshow(matches_img)
|
||||
plt.show()
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
print('Warning: not enough matching points')
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
self.prevDescriptors = copy.copy(descriptors)
|
||||
|
||||
return H
|
||||
|
||||
def applySparseOptFlow(self, raw_frame, detections=None):
|
||||
# Initialize
|
||||
# t0 = time.time()
|
||||
height, width, _ = raw_frame.shape
|
||||
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
|
||||
H = np.eye(2, 3)
|
||||
|
||||
# Downscale image
|
||||
if self.downscale > 1.0:
|
||||
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
|
||||
frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
|
||||
|
||||
# find the keypoints
|
||||
keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
|
||||
|
||||
# Handle first frame
|
||||
if not self.initializedFirstFrame:
|
||||
# Initialize data
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
# Initialization done
|
||||
self.initializedFirstFrame = True
|
||||
|
||||
return H
|
||||
|
||||
# find correspondences
|
||||
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
|
||||
|
||||
# leave good correspondences only
|
||||
prevPoints = []
|
||||
currPoints = []
|
||||
|
||||
for i in range(len(status)):
|
||||
if status[i]:
|
||||
prevPoints.append(self.prevKeyPoints[i])
|
||||
currPoints.append(matchedKeypoints[i])
|
||||
|
||||
prevPoints = np.array(prevPoints)
|
||||
currPoints = np.array(currPoints)
|
||||
|
||||
# Find rigid matrix
|
||||
if (np.size(prevPoints, 0) > 4) and (np.size(prevPoints, 0) == np.size(prevPoints, 0)):
|
||||
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
|
||||
|
||||
# Handle downscale
|
||||
if self.downscale > 1.0:
|
||||
H[0, 2] *= self.downscale
|
||||
H[1, 2] *= self.downscale
|
||||
else:
|
||||
print('Warning: not enough matching points')
|
||||
|
||||
# Store to next iteration
|
||||
self.prevFrame = frame.copy()
|
||||
self.prevKeyPoints = copy.copy(keypoints)
|
||||
|
||||
# gmc_line = str(1000 * (time.time() - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
|
||||
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
|
||||
# self.gmc_file.write(gmc_line)
|
||||
|
||||
return H
|
||||
|
||||
def applyFile(self, raw_frame, detections=None):
|
||||
line = self.gmcFile.readline()
|
||||
tokens = line.split("\t")
|
||||
H = np.eye(2, 3, dtype=np.float_)
|
||||
H[0, 0] = float(tokens[1])
|
||||
H[0, 1] = float(tokens[2])
|
||||
H[0, 2] = float(tokens[3])
|
||||
H[1, 0] = float(tokens[4])
|
||||
H[1, 1] = float(tokens[5])
|
||||
H[1, 2] = float(tokens[6])
|
||||
|
||||
return H
|
@ -0,0 +1,458 @@
|
||||
import numpy as np
|
||||
import scipy.linalg
|
||||
|
||||
# Table for the 0.95 quantile of the chi-square distribution with N degrees of freedom (contains values for N=1, ..., 9)
|
||||
# Taken from MATLAB/Octave's chi2inv function and used as Mahalanobis gating threshold.
|
||||
chi2inv95 = {1: 3.8415, 2: 5.9915, 3: 7.8147, 4: 9.4877, 5: 11.070, 6: 12.592, 7: 14.067, 8: 15.507, 9: 16.919}
|
||||
|
||||
|
||||
class KalmanFilterXYAH:
|
||||
"""
|
||||
For bytetrack
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, a, h, vx, vy, va, vh
|
||||
|
||||
contains the bounding box center position (x, y), aspect ratio a, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, a, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, a, h) with center position (x, y),
|
||||
aspect ratio a, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[3], 2 * self._std_weight_position * measurement[3], 1e-2,
|
||||
2 * self._std_weight_position * measurement[3], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[3], 1e-5, 10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-2,
|
||||
self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[3], self._std_weight_velocity * mean[3], 1e-5,
|
||||
self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
# mean = np.dot(self._motion_mat, mean)
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[3], self._std_weight_position * mean[3], 1e-1,
|
||||
self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrics of the object states at the
|
||||
previous time step.
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 3], self._std_weight_position * mean[:, 3],
|
||||
1e-2 * np.ones_like(mean[:, 3]), self._std_weight_position * mean[:, 3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity * mean[:, 3],
|
||||
1e-5 * np.ones_like(mean[:, 3]), self._std_weight_velocity * mean[:, 3]]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, a, h), where (x, y)
|
||||
is the center position, a the aspect ratio, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
|
||||
np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
||||
|
||||
|
||||
class KalmanFilterXYWH:
|
||||
"""
|
||||
For bot-sort
|
||||
A simple Kalman filter for tracking bounding boxes in image space.
|
||||
|
||||
The 8-dimensional state space
|
||||
|
||||
x, y, w, h, vx, vy, vw, vh
|
||||
|
||||
contains the bounding box center position (x, y), width w, height h,
|
||||
and their respective velocities.
|
||||
|
||||
Object motion follows a constant velocity model. The bounding box location
|
||||
(x, y, w, h) is taken as direct observation of the state space (linear
|
||||
observation model).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
ndim, dt = 4, 1.
|
||||
|
||||
# Create Kalman filter model matrices.
|
||||
self._motion_mat = np.eye(2 * ndim, 2 * ndim)
|
||||
for i in range(ndim):
|
||||
self._motion_mat[i, ndim + i] = dt
|
||||
self._update_mat = np.eye(ndim, 2 * ndim)
|
||||
|
||||
# Motion and observation uncertainty are chosen relative to the current
|
||||
# state estimate. These weights control the amount of uncertainty in
|
||||
# the model. This is a bit hacky.
|
||||
self._std_weight_position = 1. / 20
|
||||
self._std_weight_velocity = 1. / 160
|
||||
|
||||
def initiate(self, measurement):
|
||||
"""Create track from unassociated measurement.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
measurement : ndarray
|
||||
Bounding box coordinates (x, y, w, h) with center position (x, y),
|
||||
width w, and height h.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector (8 dimensional) and covariance matrix (8x8
|
||||
dimensional) of the new track. Unobserved velocities are initialized
|
||||
to 0 mean.
|
||||
|
||||
"""
|
||||
mean_pos = measurement
|
||||
mean_vel = np.zeros_like(mean_pos)
|
||||
mean = np.r_[mean_pos, mean_vel]
|
||||
|
||||
std = [
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
2 * self._std_weight_position * measurement[2], 2 * self._std_weight_position * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3],
|
||||
10 * self._std_weight_velocity * measurement[2], 10 * self._std_weight_velocity * measurement[3]]
|
||||
covariance = np.diag(np.square(std))
|
||||
return mean, covariance
|
||||
|
||||
def predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The 8 dimensional mean vector of the object state at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The 8x8 dimensional covariance matrix of the object state at the
|
||||
previous time step.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3],
|
||||
self._std_weight_velocity * mean[2], self._std_weight_velocity * mean[3]]
|
||||
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def project(self, mean, covariance):
|
||||
"""Project state distribution to measurement space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The state's mean vector (8 dimensional array).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the projected mean and covariance matrix of the given state
|
||||
estimate.
|
||||
|
||||
"""
|
||||
std = [
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3],
|
||||
self._std_weight_position * mean[2], self._std_weight_position * mean[3]]
|
||||
innovation_cov = np.diag(np.square(std))
|
||||
|
||||
mean = np.dot(self._update_mat, mean)
|
||||
covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
|
||||
return mean, covariance + innovation_cov
|
||||
|
||||
def multi_predict(self, mean, covariance):
|
||||
"""Run Kalman filter prediction step (Vectorized version).
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The Nx8 dimensional mean matrix of the object states at the previous
|
||||
time step.
|
||||
covariance : ndarray
|
||||
The Nx8x8 dimensional covariance matrics of the object states at the
|
||||
previous time step.
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the mean vector and covariance matrix of the predicted
|
||||
state. Unobserved velocities are initialized to 0 mean.
|
||||
"""
|
||||
std_pos = [
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3],
|
||||
self._std_weight_position * mean[:, 2], self._std_weight_position * mean[:, 3]]
|
||||
std_vel = [
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3],
|
||||
self._std_weight_velocity * mean[:, 2], self._std_weight_velocity * mean[:, 3]]
|
||||
sqr = np.square(np.r_[std_pos, std_vel]).T
|
||||
|
||||
motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
|
||||
motion_cov = np.asarray(motion_cov)
|
||||
|
||||
mean = np.dot(mean, self._motion_mat.T)
|
||||
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
|
||||
covariance = np.dot(left, self._motion_mat.T) + motion_cov
|
||||
|
||||
return mean, covariance
|
||||
|
||||
def update(self, mean, covariance, measurement):
|
||||
"""Run Kalman filter correction step.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
The predicted state's mean vector (8 dimensional).
|
||||
covariance : ndarray
|
||||
The state's covariance matrix (8x8 dimensional).
|
||||
measurement : ndarray
|
||||
The 4 dimensional measurement vector (x, y, w, h), where (x, y)
|
||||
is the center position, w the width, and h the height of the
|
||||
bounding box.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray, ndarray)
|
||||
Returns the measurement-corrected state distribution.
|
||||
|
||||
"""
|
||||
projected_mean, projected_cov = self.project(mean, covariance)
|
||||
|
||||
chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
|
||||
kalman_gain = scipy.linalg.cho_solve((chol_factor, lower),
|
||||
np.dot(covariance, self._update_mat.T).T,
|
||||
check_finite=False).T
|
||||
innovation = measurement - projected_mean
|
||||
|
||||
new_mean = mean + np.dot(innovation, kalman_gain.T)
|
||||
new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
|
||||
return new_mean, new_covariance
|
||||
|
||||
def gating_distance(self, mean, covariance, measurements, only_position=False, metric='maha'):
|
||||
"""Compute gating distance between state distribution and measurements.
|
||||
A suitable distance threshold can be obtained from `chi2inv95`. If
|
||||
`only_position` is False, the chi-square distribution has 4 degrees of
|
||||
freedom, otherwise 2.
|
||||
Parameters
|
||||
----------
|
||||
mean : ndarray
|
||||
Mean vector over the state distribution (8 dimensional).
|
||||
covariance : ndarray
|
||||
Covariance of the state distribution (8x8 dimensional).
|
||||
measurements : ndarray
|
||||
An Nx4 dimensional matrix of N measurements, each in
|
||||
format (x, y, a, h) where (x, y) is the bounding box center
|
||||
position, a the aspect ratio, and h the height.
|
||||
only_position : Optional[bool]
|
||||
If True, distance computation is done with respect to the bounding
|
||||
box center position only.
|
||||
Returns
|
||||
-------
|
||||
ndarray
|
||||
Returns an array of length N, where the i-th element contains the
|
||||
squared Mahalanobis distance between (mean, covariance) and
|
||||
`measurements[i]`.
|
||||
"""
|
||||
mean, covariance = self.project(mean, covariance)
|
||||
if only_position:
|
||||
mean, covariance = mean[:2], covariance[:2, :2]
|
||||
measurements = measurements[:, :2]
|
||||
|
||||
d = measurements - mean
|
||||
if metric == 'gaussian':
|
||||
return np.sum(d * d, axis=1)
|
||||
elif metric == 'maha':
|
||||
cholesky_factor = np.linalg.cholesky(covariance)
|
||||
z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
|
||||
return np.sum(z * z, axis=0) # square maha
|
||||
else:
|
||||
raise ValueError('invalid distance metric')
|
@ -0,0 +1,188 @@
|
||||
import lap
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from .kalman_filter import chi2inv95
|
||||
|
||||
|
||||
def merge_matches(m1, m2, shape):
|
||||
O, P, Q = shape
|
||||
m1 = np.asarray(m1)
|
||||
m2 = np.asarray(m2)
|
||||
|
||||
M1 = scipy.sparse.coo_matrix((np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
|
||||
M2 = scipy.sparse.coo_matrix((np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))
|
||||
|
||||
mask = M1 * M2
|
||||
match = mask.nonzero()
|
||||
match = list(zip(match[0], match[1]))
|
||||
unmatched_O = tuple(set(range(O)) - {i for i, j in match})
|
||||
unmatched_Q = tuple(set(range(Q)) - {j for i, j in match})
|
||||
|
||||
return match, unmatched_O, unmatched_Q
|
||||
|
||||
|
||||
def _indices_to_matches(cost_matrix, indices, thresh):
|
||||
matched_cost = cost_matrix[tuple(zip(*indices))]
|
||||
matched_mask = (matched_cost <= thresh)
|
||||
|
||||
matches = indices[matched_mask]
|
||||
unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
|
||||
unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
|
||||
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
|
||||
if cost_matrix.size == 0:
|
||||
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
|
||||
matches, unmatched_a, unmatched_b = [], [], []
|
||||
cost, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
|
||||
matches.extend([ix, mx] for ix, mx in enumerate(x) if mx >= 0)
|
||||
unmatched_a = np.where(x < 0)[0]
|
||||
unmatched_b = np.where(y < 0)[0]
|
||||
matches = np.asarray(matches)
|
||||
return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def ious(atlbrs, btlbrs):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
:type atlbrs: list[tlbr] | np.ndarray
|
||||
|
||||
:rtype ious np.ndarray
|
||||
"""
|
||||
ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
|
||||
if ious.size == 0:
|
||||
return ious
|
||||
|
||||
ious = bbox_ious(np.ascontiguousarray(atlbrs, dtype=np.float32), np.ascontiguousarray(btlbrs, dtype=np.float32))
|
||||
return ious
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atracks: list[STrack]
|
||||
:type btracks: list[STrack]
|
||||
|
||||
:rtype cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
|
||||
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlbr for track in atracks]
|
||||
btlbrs = [track.tlbr for track in btracks]
|
||||
_ious = ious(atlbrs, btlbrs)
|
||||
return 1 - _ious # cost matrix
|
||||
|
||||
|
||||
def v_iou_distance(atracks, btracks):
|
||||
"""
|
||||
Compute cost based on IoU
|
||||
:type atracks: list[STrack]
|
||||
:type btracks: list[STrack]
|
||||
|
||||
:rtype cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) \
|
||||
or (len(btracks) > 0 and isinstance(btracks[0], np.ndarray)):
|
||||
atlbrs = atracks
|
||||
btlbrs = btracks
|
||||
else:
|
||||
atlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in atracks]
|
||||
btlbrs = [track.tlwh_to_tlbr(track.pred_bbox) for track in btracks]
|
||||
_ious = ious(atlbrs, btlbrs)
|
||||
return 1 - _ious # cost matrix
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='cosine'):
|
||||
"""
|
||||
:param tracks: list[STrack]
|
||||
:param detections: list[BaseTrack]
|
||||
:param metric:
|
||||
:return: cost_matrix np.ndarray
|
||||
"""
|
||||
|
||||
cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
|
||||
# for i, track in enumerate(tracks):
|
||||
# cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
|
||||
track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
|
||||
cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Nomalized features
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position)
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
gating_dim = 2 if only_position else 4
|
||||
gating_threshold = chi2inv95[gating_dim]
|
||||
measurements = np.asarray([det.to_xyah() for det in detections])
|
||||
for row, track in enumerate(tracks):
|
||||
gating_distance = kf.gating_distance(track.mean, track.covariance, measurements, only_position, metric='maha')
|
||||
cost_matrix[row, gating_distance > gating_threshold] = np.inf
|
||||
cost_matrix[row] = lambda_ * cost_matrix[row] + (1 - lambda_) * gating_distance
|
||||
return cost_matrix
|
||||
|
||||
|
||||
def fuse_iou(cost_matrix, tracks, detections):
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
reid_sim = 1 - cost_matrix
|
||||
iou_dist = iou_distance(tracks, detections)
|
||||
iou_sim = 1 - iou_dist
|
||||
fuse_sim = reid_sim * (1 + iou_sim) / 2
|
||||
# det_scores = np.array([det.score for det in detections])
|
||||
# det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
|
||||
return 1 - fuse_sim # fuse cost
|
||||
|
||||
|
||||
def fuse_score(cost_matrix, detections):
|
||||
if cost_matrix.size == 0:
|
||||
return cost_matrix
|
||||
iou_sim = 1 - cost_matrix
|
||||
det_scores = np.array([det.score for det in detections])
|
||||
det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
|
||||
fuse_sim = iou_sim * det_scores
|
||||
return 1 - fuse_sim # fuse_cost
|
||||
|
||||
|
||||
def bbox_ious(box1, box2, eps=1e-7):
|
||||
"""Boxes are x1y1x2y2
|
||||
box1: np.array of shape(nx4)
|
||||
box2: np.array of shape(mx4)
|
||||
returns: np.array of shape(nxm)
|
||||
"""
|
||||
# Get the coordinates of bounding boxes
|
||||
b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
|
||||
b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
|
||||
|
||||
# Intersection area
|
||||
inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * \
|
||||
(np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)).clip(0)
|
||||
|
||||
# box2 area
|
||||
box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
|
||||
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
|
||||
return inter_area / (box2_area + box1_area[:, None] - inter_area + eps)
|
Loading…
Reference in new issue