Integrate ByteTracker and BoT-SORT trackers (#788)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -155,7 +155,8 @@ class YOLO:
|
||||
overrides = self.overrides.copy()
|
||||
overrides["conf"] = 0.25
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "predict"
|
||||
overrides["mode"] = kwargs.get("mode", "predict")
|
||||
assert overrides["mode"] in ['track', 'predict']
|
||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.predictor = self.PredictorClass(overrides=overrides)
|
||||
@ -165,6 +166,16 @@ class YOLO:
|
||||
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
@smart_inference_mode()
|
||||
def track(self, source=None, stream=False, **kwargs):
|
||||
from ultralytics.tracker.track import register_tracker
|
||||
register_tracker(self)
|
||||
# bytetrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get("conf") or 0.1
|
||||
kwargs['conf'] = conf
|
||||
kwargs['mode'] = 'track'
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
|
@ -83,7 +83,6 @@ class BasePredictor:
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
self.data = self.args.data # data_dict
|
||||
self.bs = None
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.classes = self.args.classes
|
||||
@ -136,7 +135,6 @@ class BasePredictor:
|
||||
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
|
||||
|
||||
def stream_inference(self, source=None, model=None):
|
||||
self.run_callbacks("on_predict_start")
|
||||
if self.args.verbose:
|
||||
LOGGER.info("")
|
||||
|
||||
@ -155,6 +153,7 @@ class BasePredictor:
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
|
||||
self.run_callbacks("on_predict_start")
|
||||
for batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
self.batch = batch
|
||||
@ -172,8 +171,12 @@ class BasePredictor:
|
||||
# postprocess
|
||||
with self.dt[2]:
|
||||
self.results = self.postprocess(preds, im, im0s, self.classes)
|
||||
self.run_callbacks("on_predict_postprocess_end")
|
||||
|
||||
# visualize, save, write results
|
||||
for i in range(len(im)):
|
||||
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
|
||||
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img else (path,
|
||||
im0s)
|
||||
p = Path(p)
|
||||
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
@ -184,7 +187,6 @@ class BasePredictor:
|
||||
|
||||
if self.args.save:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
yield from self.results
|
||||
|
||||
|
@ -44,6 +44,14 @@ class Results:
|
||||
setattr(r, item, getattr(self, item)[idx])
|
||||
return r
|
||||
|
||||
def update(self, boxes=None, masks=None, probs=None):
|
||||
if boxes is not None:
|
||||
self.boxes = Boxes(boxes, self.orig_shape)
|
||||
if masks is not None:
|
||||
self.masks = Masks(masks, self.orig_shape)
|
||||
if boxes is not None:
|
||||
self.probs = probs
|
||||
|
||||
def cpu(self):
|
||||
r = Results(orig_shape=self.orig_shape)
|
||||
for item in self.comp:
|
||||
@ -138,7 +146,10 @@ class Boxes:
|
||||
def __init__(self, boxes, orig_shape) -> None:
|
||||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
assert boxes.shape[-1] == 6 # xyxy, conf, cls
|
||||
n = boxes.shape[-1]
|
||||
assert n in {6, 7}, f"expected `n` in [6, 7], but got {n}" # xyxy, (track_id), conf, cls
|
||||
# TODO
|
||||
self.is_track = n == 7
|
||||
self.boxes = boxes
|
||||
self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
|
||||
else np.asarray(orig_shape)
|
||||
@ -155,6 +166,10 @@ class Boxes:
|
||||
def cls(self):
|
||||
return self.boxes[:, -1]
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self.boxes[:, -3] if self.is_track else None
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=2) # maxsize 1 should suffice
|
||||
def xywh(self):
|
||||
@ -303,7 +318,7 @@ class Masks:
|
||||
|
||||
def __getitem__(self, idx):
|
||||
masks = self.masks[idx]
|
||||
return Masks(masks, self.im_shape, self.orig_shape)
|
||||
return Masks(masks, self.orig_shape)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
name = self.__class__.__name__
|
||||
|
Reference in New Issue
Block a user