Integrate ByteTracker and BoT-SORT trackers (#788)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
@@ -197,9 +197,7 @@ def entrypoint(debug=''):
     # Define tasks and modes
     tasks = 'detect', 'segment', 'classify'
-    modes = 'train', 'val', 'predict', 'export'
+    modes = 'train', 'val', 'predict', 'export', 'track'
 
     # Define special commands
     special = {
         'help': lambda: LOGGER.info(CLI_HELP_MSG),
         'checks': checks.check_yolo,
@@ -288,7 +286,7 @@ def entrypoint(debug=''):
                 f"Inheriting 'task={model.task}' from {overrides['model']} and ignoring 'task={task}'.")
             task = model.task
     overrides['task'] = task
-    if mode == 'predict' and 'source' not in overrides:
+    if mode in {'predict', 'track'} and 'source' not in overrides:
         overrides['source'] = DEFAULT_CFG.source or ROOT / "assets" if (ROOT / "assets").exists() \
             else "https://ultralytics.com/images/bus.jpg"
         LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
@@ -64,7 +64,7 @@ augment: False # apply image augmentation to prediction sources
 agnostic_nms: False # class-agnostic NMS
 classes: # filter results by class, i.e. class=0, or class=[0,2,3]
 retina_masks: False # use high-resolution segmentation masks
 boxes: True # Show boxes in segmentation predictions
 
 # Export settings ------------------------------------------------------------------------------------------------------
 format: torchscript # format to export to
@@ -110,3 +110,7 @@ cfg: # for overriding defaults.yaml
 
 # Debug, do not modify -------------------------------------------------------------------------------------------------
 v5loader: False # use legacy YOLOv5 dataloader
+
+# Tracker settings ------------------------------------------------------------------------------------------------------
+tracker: botsort # tracker type, ['botsort', 'bytetrack']
+tracker_cfg: null # path to tracker config file
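A minimal sketch of picking the alternative tracker from Python. It assumes keyword arguments passed to `track()` are merged into the predictor overrides like any other setting (see the `YOLO.predict` hunk below); the weights file and video path are placeholders:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights
# 'tracker' mirrors the new default.yaml key; 'botsort' is the default, 'bytetrack' the alternative
results = model.track(source="video.mp4", tracker="bytetrack")
```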
@@ -155,7 +155,8 @@ class YOLO:
         overrides = self.overrides.copy()
         overrides["conf"] = 0.25
         overrides.update(kwargs)
-        overrides["mode"] = "predict"
+        overrides["mode"] = kwargs.get("mode", "predict")
+        assert overrides["mode"] in ['track', 'predict']
         overrides["save"] = kwargs.get("save", False) # not save files by default
         if not self.predictor:
             self.predictor = self.PredictorClass(overrides=overrides)
@@ -165,6 +166,16 @@ class YOLO:
         is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
         return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
 
+    @smart_inference_mode()
+    def track(self, source=None, stream=False, **kwargs):
+        from ultralytics.tracker.track import register_tracker
+        register_tracker(self)
+        # bytetrack-based method needs low confidence predictions as input
+        conf = kwargs.get("conf") or 0.1
+        kwargs['conf'] = conf
+        kwargs['mode'] = 'track'
+        return self.predict(source=source, stream=stream, **kwargs)
+
     @smart_inference_mode()
     def val(self, data=None, **kwargs):
         """
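A minimal usage sketch of the new `track()` entry point, assuming the package-level `YOLO` import and a local `video.mp4`; with `stream=True` each frame yields a `Results` object whose boxes carry a track-id column once the tracker has matched them:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # placeholder weights
for result in model.track(source="video.mp4", stream=True):
    ids = result.boxes.id  # None until the tracker assigns ids (see the Boxes change below)
    if ids is not None:
        print(ids.tolist(), result.boxes.cls.tolist())
```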
@@ -83,7 +83,6 @@ class BasePredictor:
         # Usable if setup is done
         self.model = None
         self.data = self.args.data # data_dict
         self.bs = None
         self.imgsz = None
         self.device = None
         self.classes = self.args.classes
@@ -136,7 +135,6 @@ class BasePredictor:
         self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
 
     def stream_inference(self, source=None, model=None):
-        self.run_callbacks("on_predict_start")
         if self.args.verbose:
             LOGGER.info("")
@@ -155,6 +153,7 @@ class BasePredictor:
             self.done_warmup = True
 
         self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
+        self.run_callbacks("on_predict_start")
         for batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
             self.batch = batch
@@ -172,8 +171,12 @@ class BasePredictor:
             # postprocess
             with self.dt[2]:
                 self.results = self.postprocess(preds, im, im0s, self.classes)
+            self.run_callbacks("on_predict_postprocess_end")
 
             # visualize, save, write results
             for i in range(len(im)):
-                p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
+                p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
+                                                                                                                 im0s)
                 p = Path(p)
 
                 if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@@ -184,7 +187,6 @@ class BasePredictor:
 
                 if self.args.save:
                     self.save_preds(vid_cap, i, str(self.save_dir / p.name))
 
             self.run_callbacks("on_predict_batch_end")
             yield from self.results
@@ -44,6 +44,14 @@ class Results:
             setattr(r, item, getattr(self, item)[idx])
         return r
 
+    def update(self, boxes=None, masks=None, probs=None):
+        if boxes is not None:
+            self.boxes = Boxes(boxes, self.orig_shape)
+        if masks is not None:
+            self.masks = Masks(masks, self.orig_shape)
+        if boxes is not None:
+            self.probs = probs
+
     def cpu(self):
         r = Results(orig_shape=self.orig_shape)
         for item in self.comp:
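The new `update()` is the hook a tracker callback can use to write matched boxes back into a frame's `Results`. A sketch of that pattern, using the 7-column layout (xyxy, track_id, conf, cls) from the `Boxes` change below; the callback body is illustrative only, not the shipped `register_tracker` implementation:

```python
import torch

def on_predict_postprocess_end(predictor):
    # Illustrative only: replace each frame's boxes with "tracked" boxes that
    # carry an extra track_id column -> (x1, y1, x2, y2, track_id, conf, cls).
    for result in predictor.results:
        if result.boxes is None or len(result.boxes.boxes) == 0:
            continue
        det = result.boxes.boxes  # (n, 6) tensor: xyxy, conf, cls
        ids = torch.arange(len(det), dtype=det.dtype).unsqueeze(1)  # placeholder track ids
        result.update(boxes=torch.cat((det[:, :4], ids, det[:, 4:]), dim=1))  # now (n, 7)
```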
@@ -138,7 +146,10 @@ class Boxes:
     def __init__(self, boxes, orig_shape) -> None:
         if boxes.ndim == 1:
             boxes = boxes[None, :]
-        assert boxes.shape[-1] == 6 # xyxy, conf, cls
+        n = boxes.shape[-1]
+        assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
+        # TODO
+        self.is_track = n == 7
         self.boxes = boxes
         self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
             else np.asarray(orig_shape)
@@ -155,6 +166,10 @@ class Boxes:
     def cls(self):
         return self.boxes[:, -1]
 
+    @property
+    def id(self):
+        return self.boxes[:, -3] if self.is_track else None
+
     @property
     @lru_cache(maxsize=2) # maxsize 1 should suffice
     def xywh(self):
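A small sketch of the two layouts the updated `Boxes` now accepts. The numbers are made up and the `ultralytics.yolo.engine.results` import path is assumed for this release; only the column order (xyxy, optional track_id, conf, cls) comes from the diff:

```python
import torch
from ultralytics.yolo.engine.results import Boxes  # assumed module path

orig_shape = (480, 640)
plain = Boxes(torch.tensor([[10., 20., 110., 220., 0.91, 0.]]), orig_shape)         # n == 6
tracked = Boxes(torch.tensor([[10., 20., 110., 220., 7., 0.91, 0.]]), orig_shape)   # n == 7

print(plain.id)            # None, is_track is False
print(tracked.id.item())   # 7.0, read from the track_id column
print(tracked.cls.item())  # 0.0, class stays in the last column
```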
@@ -303,7 +318,7 @@ class Masks:
 
     def __getitem__(self, idx):
         masks = self.masks[idx]
-        return Masks(masks, self.im_shape, self.orig_shape)
+        return Masks(masks, self.orig_shape)
 
     def __getattr__(self, attr):
         name = self.__class__.__name__
@@ -91,6 +91,10 @@ def on_predict_batch_end(predictor):
     pass
 
 
+def on_predict_postprocess_end(predictor):
+    pass
+
+
 def on_predict_end(predictor):
     pass
@@ -130,6 +134,7 @@ default_callbacks = {
     # Run in predictor
     'on_predict_start': [on_predict_start],
     'on_predict_batch_start': [on_predict_batch_start],
+    'on_predict_postprocess_end': [on_predict_postprocess_end],
     'on_predict_batch_end': [on_predict_batch_end],
     'on_predict_end': [on_predict_end],
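The new `on_predict_postprocess_end` event gives the trackers a window to rewrite detections after postprocessing but before results are drawn or saved. A hedged sketch of attaching a custom hook there; the `default_callbacks` structure is taken from the hunk above, while the module path and registering by list-append are assumptions, not the only registration route:

```python
from ultralytics.yolo.utils.callbacks.base import default_callbacks  # assumed module path

def count_detections(predictor):
    # predictor.results is populated right before this event fires (see stream_inference above)
    total = sum(len(r.boxes.boxes) for r in predictor.results if r.boxes is not None)
    print(f"postprocess kept {total} boxes in this batch")

# default_callbacks maps event names to lists of handlers, so an extra hook can be appended
default_callbacks['on_predict_postprocess_end'].append(count_detections)
```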
@@ -250,7 +250,7 @@ def check_file(file, suffix='', download=True):
         return file
     else: # search
         files = []
-        for d in 'models', 'yolo/data': # search directories
+        for d in 'models', 'yolo/data', 'tracker/cfg': # search directories
             files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
         if not files:
             raise FileNotFoundError(f"'{file}' does not exist")
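With `tracker/cfg` added to the search directories, a bare tracker config filename should resolve without a full path. A quick sketch, assuming the `ultralytics.yolo.utils.checks` module path and that a `botsort.yaml` ships under `ultralytics/tracker/cfg` in this release:

```python
from ultralytics.yolo.utils.checks import check_file  # assumed module path

# a bare filename is now also searched under ROOT / 'tracker' / 'cfg',
# so tracker configs resolve the same way model and data YAMLs do
cfg_path = check_file('botsort.yaml')
print(cfg_path)
```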
@@ -68,8 +68,8 @@ class DetectionPredictor(BasePredictor):
                     f.write(('%g ' * len(line)).rstrip() % line + '\n')
                 if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
                     c = int(cls) # integer class
-                    label = None if self.args.hide_labels else (
-                        self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
+                    name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
+                    label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                     self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
                 if self.args.save_crop:
                     save_one_box(d.xyxy,
@@ -82,8 +82,8 @@ class SegmentationPredictor(DetectionPredictor):
 
                 if self.args.save or self.args.save_crop or self.args.show: # Add bbox to image
                     c = int(cls) # integer class
-                    label = None if self.args.hide_labels else (
-                        self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
+                    name = f"id:{int(d.id.item())} {self.model.names[c]}" if d.id is not None else self.model.names[c]
+                    label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                     self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True)) if self.args.boxes else None
                 if self.args.save_crop:
                     save_one_box(d.xyxy,
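Both predictor hunks change only the annotation string: when a box carries a track id, it is prepended to the class name. A tiny illustration of the resulting label (names and values are made up):

```python
conf, c, track_id = 0.87, 0, 3
names = {0: 'person'}

name = f"id:{track_id} {names[c]}" if track_id is not None else names[c]
label = f'{name} {conf:.2f}'
print(label)  # -> "id:3 person 0.87"
```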