|
|
|
@ -30,13 +30,13 @@ from collections import defaultdict
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from ultralytics.nn.autobackend import AutoBackend
|
|
|
|
|
from ultralytics.yolo.cfg import get_cfg
|
|
|
|
|
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
|
|
|
|
|
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
|
|
|
|
from ultralytics.yolo.data import load_inference_source
|
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
|
|
|
|
|
from ultralytics.yolo.utils.files import increment_path
|
|
|
|
|
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
|
|
|
|
|
|
|
|
@ -76,6 +76,8 @@ class BasePredictor:
|
|
|
|
|
if self.args.conf is None:
|
|
|
|
|
self.args.conf = 0.25 # default conf=0.25
|
|
|
|
|
self.done_warmup = False
|
|
|
|
|
if self.args.show:
|
|
|
|
|
self.args.show = check_imshow(warn=True)
|
|
|
|
|
|
|
|
|
|
# Usable if setup is done
|
|
|
|
|
self.model = None
|
|
|
|
@ -88,6 +90,7 @@ class BasePredictor:
|
|
|
|
|
self.vid_path, self.vid_writer = None, None
|
|
|
|
|
self.annotator = None
|
|
|
|
|
self.data_path = None
|
|
|
|
|
self.source_type = None
|
|
|
|
|
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
|
|
|
|
|
callbacks.add_integration_callbacks(self)
|
|
|
|
|
|
|
|
|
@ -103,53 +106,6 @@ class BasePredictor:
|
|
|
|
|
def postprocess(self, preds, img, orig_img, classes=None):
    """Post-processing hook; this implementation hands the predictions back untouched.

    Args:
        preds: Predictions to post-process.
        img: Accepted for interface compatibility; not used here.
        orig_img: Accepted for interface compatibility; not used here.
        classes: Accepted for interface compatibility; not used here.

    Returns:
        The ``preds`` argument, unchanged.
    """
    return preds
|
|
|
|
|
|
|
|
|
|
def setup_source(self, source=None):
    """Create ``self.dataset`` for ``source`` and store the loader settings on the predictor.

    The source is classified through ``self.check_source`` and the matching
    loader class is instantiated with the model's stride / ``pt`` flag and
    the validated image size.

    Args:
        source: Input to run inference on; forwarded as-is to ``self.check_source``.

    Raises:
        Exception: If ``self.model`` is not set.
    """
    if not self.model:
        raise Exception("setup model before setting up source!")

    # source
    source, webcam, screenshot, from_img = self.check_source(source)

    # model
    stride, pt = self.model.stride, self.model.pt
    imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2)  # check image size

    # Dataloader: choose the loader class and the keyword arguments only it takes
    if webcam:
        self.args.show = check_imshow(warn=True)
        loader, extra = LoadStreams, {'vid_stride': self.args.vid_stride}
    elif screenshot:
        loader, extra = LoadScreenshots, {}
    elif from_img:
        loader, extra = LoadPilAndNumpy, {}
    else:
        loader, extra = LoadImages, {'vid_stride': self.args.vid_stride}

    self.dataset = loader(source,
                          imgsz=imgsz,
                          stride=stride,
                          auto=pt,
                          transforms=getattr(self.model.model, 'transforms', None),
                          **extra)

    # Only the stream loader changes the batch size; every other loader uses 1
    bs = len(self.dataset) if webcam else 1  # batch_size
    self.vid_path, self.vid_writer = [None] * bs, [None] * bs

    self.webcam = webcam
    self.screenshot = screenshot
    self.from_img = from_img
    self.imgsz = imgsz
    self.bs = bs
|
|
|
|
|
|
|
|
|
|
@smart_inference_mode()
|
|
|
|
|
def __call__(self, source=None, model=None, stream=False):
|
|
|
|
|
if stream:
|
|
|
|
@ -163,14 +119,29 @@ class BasePredictor:
|
|
|
|
|
for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def setup_source(self, source):
    """Build the inference dataset for ``source`` and reset the video writer slots.

    Sets ``self.imgsz``, ``self.dataset`` and ``self.source_type``, and
    re-creates ``self.vid_path`` / ``self.vid_writer`` as lists of ``None``
    with one entry per ``self.dataset.bs``.

    Args:
        source: Input handed to ``load_inference_source``.

    Raises:
        Exception: If ``self.model`` is not set.
    """
    model = self.model
    if not model:
        raise Exception("Model not initialized!")

    self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size

    loader_kwargs = dict(
        source=source,
        transforms=getattr(model.model, 'transforms', None),
        imgsz=self.imgsz,
        vid_stride=self.args.vid_stride,
        stride=model.stride,
        auto=model.pt,
    )
    dataset = load_inference_source(**loader_kwargs)

    self.dataset = dataset
    self.source_type = dataset.source_type

    # Two independent lists, one slot per batch entry
    batch = dataset.bs
    self.vid_path = [None] * batch
    self.vid_writer = [None] * batch
|
|
|
|
|
|
|
|
|
|
def stream_inference(self, source=None, model=None):
|
|
|
|
|
self.run_callbacks("on_predict_start")
|
|
|
|
|
|
|
|
|
|
# setup model
|
|
|
|
|
if not self.model:
|
|
|
|
|
self.setup_model(model)
|
|
|
|
|
# setup source. Run every time predict is called
|
|
|
|
|
self.setup_source(source)
|
|
|
|
|
# setup source every time predict is called
|
|
|
|
|
self.setup_source(source if source is not None else self.args.source)
|
|
|
|
|
|
|
|
|
|
# check if save_dir/ label file exists
|
|
|
|
|
if self.args.save or self.args.save_txt:
|
|
|
|
|
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
|
|
|
@ -198,7 +169,7 @@ class BasePredictor:
|
|
|
|
|
with self.dt[2]:
|
|
|
|
|
self.results = self.postprocess(preds, im, im0s, self.classes)
|
|
|
|
|
for i in range(len(im)):
|
|
|
|
|
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
|
|
|
|
|
p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
|
|
|
|
|
p = Path(p)
|
|
|
|
|
|
|
|
|
|
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
|
|
|
@ -237,21 +208,6 @@ class BasePredictor:
|
|
|
|
|
self.device = device
|
|
|
|
|
self.model.eval()
|
|
|
|
|
|
|
|
|
|
def check_source(self, source):
    """Classify an inference source and normalise it for the dataset loaders.

    Args:
        source: The input to inspect. When ``None``, ``self.args.source`` is used.
            ``str``, ``int`` and ``Path`` values are converted to ``str`` and
            inspected; any other type is returned untouched and flagged as
            ``from_img``.

    Returns:
        tuple: ``(source, webcam, screenshot, from_img)`` where

        - ``source`` is the (possibly stringified or downloaded) source,
        - ``webcam`` is True for numeric strings, ``.streams`` paths, and URLs
          that do not end in a known image/video extension,
        - ``screenshot`` is True when the string starts with ``'screen'``
          (case-insensitive),
        - ``from_img`` is True when ``source`` is not a ``str``/``int``/``Path``.
    """
    if source is None:
        source = self.args.source

    webcam = screenshot = from_img = False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        lowered = source.lower()
        # Match the extension case-insensitively so that e.g. 'clip.MP4' or
        # 'photo.JPG' URLs are treated as files rather than live streams.
        # Assumes IMG_FORMATS / VID_FORMATS hold lowercase extensions - TODO confirm.
        is_file = Path(source).suffix[1:].lower() in (IMG_FORMATS + VID_FORMATS)
        is_url = lowered.startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        screenshot = lowered.startswith('screen')
        if is_url and is_file:
            source = check_file(source)  # download
    else:
        from_img = True
    return source, webcam, screenshot, from_img
|
|
|
|
|
|
|
|
|
|
def show(self, p):
|
|
|
|
|
im0 = self.annotator.result()
|
|
|
|
|
if platform.system() == 'Linux' and p not in self.windows:
|
|
|
|
|