ultralytics 8.0.71 updates and fixes (#1907)

Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Pavel Bugneac <50273042+pavelbugneac@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-04-08 21:10:36 +02:00
committed by GitHub
parent c38b17a0d8
commit 4e997013bc
19 changed files with 103 additions and 39 deletions

View File

@ -1,5 +1,7 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from functools import partial
import torch
from ultralytics.yolo.utils import IterableSimpleNamespace, yaml_load
@ -10,7 +12,19 @@ from .trackers import BOTSORT, BYTETracker
TRACKER_MAP = {'bytetrack': BYTETracker, 'botsort': BOTSORT}
def on_predict_start(predictor):
def on_predict_start(predictor, persist=False):
"""
Initialize trackers for object tracking during prediction.
Args:
predictor (object): The predictor object to initialize trackers for.
persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
Raises:
AssertionError: If the tracker_type is not 'bytetrack' or 'botsort'.
"""
if hasattr(predictor, 'trackers') and persist:
return
tracker = check_yaml(predictor.args.tracker)
cfg = IterableSimpleNamespace(**yaml_load(tracker))
assert cfg.tracker_type in ['bytetrack', 'botsort'], \
@ -38,6 +52,14 @@ def on_predict_postprocess_end(predictor):
predictor.results[i].update(boxes=torch.as_tensor(tracks[:, :-1]))
def register_tracker(model):
model.add_callback('on_predict_start', on_predict_start)
def register_tracker(model, persist):
"""
Register tracking callbacks to the model for object tracking during prediction.
Args:
model (object): The model object to register tracking callbacks for.
persist (bool): Whether to persist the trackers if they already exist.
"""
model.add_callback('on_predict_start', partial(on_predict_start, persist=persist))
model.add_callback('on_predict_postprocess_end', on_predict_postprocess_end)

View File

@ -277,12 +277,13 @@ class BYTETracker:
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
self.lost_stracks.extend(lost_stracks)
self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
self.removed_stracks.extend(removed_stracks)
self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
output = [
track.tlbr.tolist() + [track.track_id, track.score, track.cls, track.idx] for track in self.tracked_stracks
if track.is_activated]
return np.asarray(output, dtype=np.float32)
self.removed_stracks.extend(removed_stracks)
if len(self.removed_stracks) > 1000:
self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum
return np.asarray(
[x.tlbr.tolist() + [x.track_id, x.score, x.cls, x.idx] for x in self.tracked_stracks if x.is_activated],
dtype=np.float32)
def get_kalmanfilter(self):
return KalmanFilterXYAH()
@ -319,12 +320,16 @@ class BYTETracker:
@staticmethod
def sub_stracks(tlista, tlistb):
""" DEPRECATED CODE in https://github.com/ultralytics/ultralytics/pull/1890/
stracks = {t.track_id: t for t in tlista}
for t in tlistb:
tid = t.track_id
if stracks.get(tid, 0):
del stracks[tid]
return list(stracks.values())
"""
track_ids_b = {t.track_id for t in tlistb}
return [t for t in tlista if t.track_id not in track_ids_b]
@staticmethod
def remove_duplicate_stracks(stracksa, stracksb):