Integrate ByteTracker and BoT-SORT trackers (#788)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing
2023-02-16 00:23:03 +08:00
committed by GitHub
parent d99e04daa1
commit ed6c54da7a
24 changed files with 1635 additions and 19 deletions

View File

@ -155,7 +155,8 @@ class YOLO:
overrides = self.overrides.copy()
overrides["conf"] = 0.25
overrides.update(kwargs)
overrides["mode"] = "predict"
overrides["mode"] = kwargs.get("mode", "predict")
assert overrides["mode"] in ['track', 'predict']
overrides["save"] = kwargs.get("save", False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
@ -165,6 +166,16 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker.track import register_tracker
register_tracker(self)
# bytetrack-based method needs low confidence predictions as input
conf = kwargs.get("conf") or 0.1
kwargs['conf'] = conf
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
def val(self, data=None, **kwargs):
"""

View File

@ -83,7 +83,6 @@ class BasePredictor:
# Usable if setup is done
self.model = None
self.data = self.args.data # data_dict
self.bs = None
self.imgsz = None
self.device = None
self.classes = self.args.classes
@ -136,7 +135,6 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")
if self.args.verbose:
LOGGER.info("")
@ -155,6 +153,7 @@ class BasePredictor:
self.done_warmup = True
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
self.run_callbacks("on_predict_start")
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
self.batch = batch
@ -172,8 +171,12 @@ class BasePredictor:
# postprocess
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes)
self.run_callbacks("on_predict_postprocess_end")
# visualize, save, write results
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
im0s)
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@ -184,7 +187,6 @@ class BasePredictor:
if self.args.save:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks("on_predict_batch_end")
yield from self.results

View File

@ -44,6 +44,14 @@ class Results:
setattr(r, item, getattr(self, item)[idx])
return r
def update(self, boxes=None, masks=None, probs=None):
if boxes is not None:
self.boxes = Boxes(boxes, self.orig_shape)
if masks is not None:
self.masks = Masks(masks, self.orig_shape)
if boxes is not None:
self.probs = probs
def cpu(self):
r = Results(orig_shape=self.orig_shape)
for item in self.comp:
@ -138,7 +146,10 @@ class Boxes:
def __init__(self, boxes, orig_shape) -> None:
if boxes.ndim == 1:
boxes = boxes[None, :]
assert boxes.shape[-1] == 6 # xyxy, conf, cls
n = boxes.shape[-1]
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
# TODO
self.is_track = n == 7
self.boxes = boxes
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
else np.asarray(orig_shape)
@ -155,6 +166,10 @@ class Boxes:
def cls(self):
return self.boxes[:, -1]
@property
def id(self):
return self.boxes[:, -3] if self.is_track else None
@property
@lru_cache(maxsize=2) # maxsize 1 should suffice
def xywh(self):
@ -303,7 +318,7 @@ class Masks:
def __getitem__(self, idx):
masks = self.masks[idx]
return Masks(masks, self.im_shape, self.orig_shape)
return Masks(masks, self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__