Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-29 05:44:26 +05:30
committed by GitHub
parent a5410ed79e
commit 0609561549
9 changed files with 174 additions and 73 deletions

View File

@ -233,6 +233,13 @@ class YOLO:
"""
return self.model.names
@property
def transforms(self):
"""
Returns transform of the loaded model.
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@staticmethod
def add_callback(event: str, func):
"""

View File

@ -30,13 +30,13 @@ from collections import defaultdict
from pathlib import Path
import cv2
import torch
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -76,6 +76,8 @@ class BasePredictor:
if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False
if self.args.show:
self.args.show = check_imshow(warn=True)
# Usable if setup is done
self.model = None
@ -88,6 +90,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
self.source_type = None
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self)
@ -103,53 +106,6 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img, classes=None):
return preds
def setup_source(self, source=None):
if not self.model:
raise Exception("setup model before setting up source!")
# source
source, webcam, screenshot, from_img = self.check_source(source)
# model
stride, pt = self.model.stride, self.model.pt
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
# Dataloader
bs = 1 # batch_size
if webcam:
self.args.show = check_imshow(warn=True)
self.dataset = LoadStreams(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
bs = len(self.dataset)
elif screenshot:
self.dataset = LoadScreenshots(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
elif from_img:
self.dataset = LoadPilAndNumpy(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
else:
self.dataset = LoadImages(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
self.webcam = webcam
self.screenshot = screenshot
self.from_img = from_img
self.imgsz = imgsz
self.bs = bs
@smart_inference_mode()
def __call__(self, source=None, model=None, stream=False):
if stream:
@ -163,14 +119,29 @@ class BasePredictor:
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass
def setup_source(self, source):
if not self.model:
raise Exception("Model not initialized!")
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(source=source,
transforms=getattr(self.model.model, 'transforms', None),
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stride=self.model.stride,
auto=self.model.pt)
self.source_type = self.dataset.source_type
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start")
# setup model
if not self.model:
self.setup_model(model)
# setup source. Run every time predict is called
self.setup_source(source)
# setup source every time predict is called
self.setup_source(source if source is not None else self.args.source)
# check if save_dir/ label file exists
if self.args.save or self.args.save_txt:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@ -198,7 +169,7 @@ class BasePredictor:
with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes)
for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@ -237,21 +208,6 @@ class BasePredictor:
self.device = device
self.model.eval()
def check_source(self, source):
source = source if source is not None else self.args.source
webcam, screenshot, from_img = False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
else:
from_img = True
return source, webcam, screenshot, from_img
def show(self, p):
im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows: