Integrate ByteTracker and BoT-SORT trackers (#788)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing
2023-02-16 00:23:03 +08:00
committed by GitHub
parent d99e04daa1
commit ed6c54da7a
24 changed files with 1635 additions and 19 deletions

View File

@ -197,9 +197,7 @@ def entrypoint(debug=''):
# Define tasks and modes
tasks = 'detect', 'segment', 'classify'
modes = 'train', 'val', 'predict', 'export'
# Define special commands
modes = 'train', 'val', 'predict', 'export', 'track'
special = {
'help': lambda: LOGGER.info(CLI_HELP_MSG),
'checks': checks.check_yolo,
@ -288,7 +286,7 @@ def entrypoint(debug=''):
f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
task = model.task
overrides['task'] = task
if mode == 'predict' and 'source' not in overrides:
if mode in {'predict', 'track'} and 'source' not in overrides:
overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
else "https://ultralytics.com/images/bus.jpg"
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")

View File

@ -64,7 +64,7 @@ augment: False # apply image augmentation to prediction sources
agnostic_nms: False # class-agnostic NMS
classes: # filter results by class, i.e. class=0, or class=[0,2,3]
retina_masks: False # use high-resolution segmentation masks
boxes: True # Show boxes in segmentation predictions
boxes: True # Show boxes in segmentation predictions
# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript # format to export to
@ -110,3 +110,7 @@ cfg: # for overriding defaults.yaml
# Debug, do not modify -------------------------------------------------------------------------------------------------
v5loader: False # use legacy YOLOv5 dataloader
# Tracker settings ------------------------------------------------------------------------------------------------------
tracker: botsort # tracker type, ['botsort', 'bytetrack']
tracker_cfg: null # path to tracker config file

View File

@ -155,7 +155,8 @@ class YOLO:
overrides = self.overrides.copy()
overrides["conf"] = 0.25
overrides.update(kwargs)
overrides["mode"] = "predict"
overrides["mode"] = kwargs.get("mode", "predict")
assert overrides["mode"] in ['track', 'predict']
overrides["save"] = kwargs.get("save", False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
@ -165,6 +166,16 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker.track import register_tracker
register_tracker(self)
# bytetrack-based method needs low confidence predictions as input
conf = kwargs.get("conf") or 0.1
kwargs['conf'] = conf
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
def val(self, data=None, **kwargs):
"""

View File

@ -83,7 +83,6 @@ class BasePredictor:
# Usable if setup is done
self.model = None
self.data = self.args.data # data_dict
self.bs = None
self.imgsz = None
self.device = None
self.classes = self.args.classes
@ -136,7 +135,6 @@ class BasePredictor:
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")
if self.args.verbose:
LOGGER.info("")
@ -155,6 +153,7 @@ class BasePredictor:
self.done_warmup = True
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
self.run_callbacks("on_predict_start")
for batch in self.dataset:
self.run_callbacks("on_predict_batch_start")
self.batch = batch
@ -172,8 +171,12 @@ class BasePredictor:
# postprocess
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes)
self.run_callbacks("on_predict_postprocess_end")
# visualize, save, write results
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
im0s)
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@ -184,7 +187,6 @@ class BasePredictor:
if self.args.save:
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
self.run_callbacks("on_predict_batch_end")
yield from self.results

View File

@ -44,6 +44,14 @@ class Results:
setattr(r, item, getattr(self, item)[idx])
return r
def update(self, boxes=None, masks=None, probs=None):
if boxes is not None:
self.boxes = Boxes(boxes, self.orig_shape)
if masks is not None:
self.masks = Masks(masks, self.orig_shape)
if boxes is not None:
self.probs = probs
def cpu(self):
r = Results(orig_shape=self.orig_shape)
for item in self.comp:
@ -138,7 +146,10 @@ class Boxes:
def __init__(self, boxes, orig_shape) -> None:
if boxes.ndim == 1:
boxes = boxes[None, :]
assert boxes.shape[-1] == 6 # xyxy, conf, cls
n = boxes.shape[-1]
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
# TODO
self.is_track = n == 7
self.boxes = boxes
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
else np.asarray(orig_shape)
@ -155,6 +166,10 @@ class Boxes:
def cls(self):
return self.boxes[:, -1]
@property
def id(self):
return self.boxes[:, -3] if self.is_track else None
@property
@lru_cache(maxsize=2) # maxsize 1 should suffice
def xywh(self):
@ -303,7 +318,7 @@ class Masks:
def __getitem__(self, idx):
masks = self.masks[idx]
return Masks(masks, self.im_shape, self.orig_shape)
return Masks(masks, self.orig_shape)
def __getattr__(self, attr):
name = self.__class__.__name__

View File

@ -91,6 +91,10 @@ def on_predict_batch_end(predictor):
pass
def on_predict_postprocess_end(predictor):
pass
def on_predict_end(predictor):
pass
@ -130,6 +134,7 @@ default_callbacks = {
# Run in predictor
'on_predict_start': [on_predict_start],
'on_predict_batch_start': [on_predict_batch_start],
'on_predict_postprocess_end': [on_predict_postprocess_end],
'on_predict_batch_end': [on_predict_batch_end],
'on_predict_end': [on_predict_end],

View File

@ -250,7 +250,7 @@ def check_file(file, suffix='', download=True):
return file
else: # search
files = []
for d in 'models', 'yolo/data': # search directories
for d in 'models', 'yolo/data', 'tracker/cfg': # search directories
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
if not files:
raise FileNotFoundError(f"'{file}' does not exist")

View File

@ -68,8 +68,8 @@ class DetectionPredictor(BasePredictor):
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
c = int(cls) # integer class
label = None if self.args.hide_labels else (
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
if self.args.save_crop:
save_one_box(d.xyxy,

View File

@ -82,8 +82,8 @@ class SegmentationPredictor(DetectionPredictor):
if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
c = int(cls) # integer class
label = None if self.args.hide_labels else (
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
if self.args.save_crop:
save_one_box(d.xyxy,