Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent a5410ed79e
commit 0609561549
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,10 +3,12 @@
from pathlib import Path from pathlib import Path
import cv2 import cv2
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import ROOT, SETTINGS from ultralytics.yolo.utils import ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
@ -40,6 +42,7 @@ def test_predict_dir():
def test_predict_img(): def test_predict_img():
model = YOLO(MODEL) model = YOLO(MODEL)
img = Image.open(str(SOURCE)) img = Image.open(str(SOURCE))
output = model(source=img, save=True, verbose=True) # PIL output = model(source=img, save=True, verbose=True) # PIL
@ -54,6 +57,16 @@ def test_predict_img():
tens = torch.zeros(320, 640, 3) tens = torch.zeros(320, 640, 3)
output = model(tens.numpy()) output = model(tens.numpy())
assert len(output) == 1, "predict test failed" assert len(output) == 1, "predict test failed"
# test multiple source
imgs = [
SOURCE, # filename
Path(SOURCE), # Path
'https://ultralytics.com/images/zidane.jpg', # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3))] # numpy
output = model(imgs)
assert len(output) == 6, "predict test failed!"
def test_val(): def test_val():
@ -129,3 +142,28 @@ def test_workflow():
model.val() model.val()
model.predict(SOURCE) model.predict(SOURCE)
model.export(format="onnx", opset=12) # export a model to ONNX format model.export(format="onnx", opset=12) # export a model to ONNX format
def test_predict_callback_and_setup():
def on_predict_batch_end(predictor):
# results -> List[batch_size]
path, _, im0s, _, _ = predictor.batch
# print('on_predict_batch_end', im0s[0].shape)
bs = [predictor.bs for i in range(0, len(path))]
predictor.results = zip(predictor.results, im0s, bs)
model = YOLO("yolov8n.pt")
model.add_callback("on_predict_batch_end", on_predict_batch_end)
dataset = load_inference_source(source=SOURCE, transforms=model.transforms)
bs = dataset.bs # access predictor properties
results = model.predict(dataset, stream=True) # source already setup
for _, (result, im0, bs) in enumerate(results):
print('test_callback', im0.shape)
print('test_callback', bs)
boxes = result.boxes # Boxes object for bbox outputs
print(boxes)
test_predict_img()

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, GPL-3.0 license
from .base import BaseDataset from .base import BaseDataset
from .build import build_classification_dataloader, build_dataloader from .build import build_classification_dataloader, build_dataloader, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .dataset_wrappers import MixAndRectDataset from .dataset_wrappers import MixAndRectDataset

@ -2,11 +2,18 @@
import os import os
import random import random
from pathlib import Path
import numpy as np import numpy as np
import torch import torch
from PIL import Image
from torch.utils.data import DataLoader, dataloader, distributed from torch.utils.data import DataLoader, dataloader, distributed
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
LoadStreams, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
from ..utils import LOGGER, colorstr from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset from .dataset import ClassificationDataset, YOLODataset
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
pin_memory=PIN_MEMORY, pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker, worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True) generator=generator) # or DataLoader(persistent_workers=True)
def check_source(source):
webcam, screenshot, from_img, in_memory = False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
elif isinstance(source, tuple(LOADERS)):
in_memory = True
elif isinstance(source, (list, tuple)):
source = autocast_list(source) # convert all list elements to PIL or np arrays
from_img = True
elif isinstance(source, ((Image.Image, np.ndarray))):
from_img = True
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
return source, webcam, screenshot, from_img, in_memory
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
"""
TODO: docs
"""
# source
source, webcam, screenshot, from_img, in_memory = check_source(source)
source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type
# Dataloader
if in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source,
imgsz=imgsz,
stride=stride,
auto=auto,
transforms=transforms,
vid_stride=vid_stride)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
elif from_img:
dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
else:
dataset = LoadImages(source,
imgsz=imgsz,
stride=stride,
auto=auto,
transforms=transforms,
vid_stride=vid_stride)
setattr(dataset, 'source_type', source_type) # attach source types
return dataset

@ -4,14 +4,16 @@ import glob
import math import math
import os import os
import time import time
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from threading import Thread from threading import Thread
from urllib.parse import urlparse from urllib.parse import urlparse
import cv2 import cv2
import numpy as np import numpy as np
import requests
import torch import torch
from PIL import Image from PIL import Image, ImageOps
from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements from ultralytics.yolo.utils.checks import check_requirements
@dataclass
class SourceTypes:
webcam: bool = False
screenshot: bool = False
from_img: bool = False
class LoadStreams: class LoadStreams:
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` # YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
@ -63,6 +72,8 @@ class LoadStreams:
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect self.auto = auto and self.rect
self.transforms = transforms # optional self.transforms = transforms # optional
self.bs = self.__len__()
if not self.rect: if not self.rect:
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.') LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
@ -128,6 +139,7 @@ class LoadScreenshots:
self.mode = 'stream' self.mode = 'stream'
self.frame = 0 self.frame = 0
self.sct = mss.mss() self.sct = mss.mss()
self.bs = 1
# Parse monitor shape # Parse monitor shape
monitor = self.sct.monitors[self.screen] monitor = self.sct.monitors[self.screen]
@ -185,6 +197,7 @@ class LoadImages:
self.auto = auto self.auto = auto
self.transforms = transforms # optional self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos): if any(videos):
self.orientation = None # rotation degrees self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video self._new_video(videos[0]) # new video
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
self.mode = 'image' self.mode = 'image'
# generate fake paths # generate fake paths
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))] self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
self.bs = 1
@staticmethod @staticmethod
def _single_check(im): def _single_check(im):
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
return self return self
def autocast_list(source):
"""
Merges a list of source of different types into a list of numpy arrays or PIL images
"""
files = []
for _, im in enumerate(source):
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
return files
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
if __name__ == "__main__": if __name__ == "__main__":
img = cv2.imread(str(ROOT / "assets/bus.jpg")) img = cv2.imread(str(ROOT / "assets/bus.jpg"))
dataset = LoadPilAndNumpy(im0=img) dataset = LoadPilAndNumpy(im0=img)

@ -233,6 +233,13 @@ class YOLO:
""" """
return self.model.names return self.model.names
@property
def transforms(self):
"""
Returns transform of the loaded model.
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@staticmethod @staticmethod
def add_callback(event: str, func): def add_callback(event: str, func):
""" """

@ -30,13 +30,13 @@ from collections import defaultdict
from pathlib import Path from pathlib import Path
import cv2 import cv2
import torch
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams from ultralytics.yolo.data import load_inference_source
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.checks import check_imgsz, check_imshow
from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
@ -76,6 +76,8 @@ class BasePredictor:
if self.args.conf is None: if self.args.conf is None:
self.args.conf = 0.25 # default conf=0.25 self.args.conf = 0.25 # default conf=0.25
self.done_warmup = False self.done_warmup = False
if self.args.show:
self.args.show = check_imshow(warn=True)
# Usable if setup is done # Usable if setup is done
self.model = None self.model = None
@ -88,6 +90,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None self.vid_path, self.vid_writer = None, None
self.annotator = None self.annotator = None
self.data_path = None self.data_path = None
self.source_type = None
self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
@ -103,53 +106,6 @@ class BasePredictor:
def postprocess(self, preds, img, orig_img, classes=None): def postprocess(self, preds, img, orig_img, classes=None):
return preds return preds
def setup_source(self, source=None):
if not self.model:
raise Exception("setup model before setting up source!")
# source
source, webcam, screenshot, from_img = self.check_source(source)
# model
stride, pt = self.model.stride, self.model.pt
imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size
# Dataloader
bs = 1 # batch_size
if webcam:
self.args.show = check_imshow(warn=True)
self.dataset = LoadStreams(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
bs = len(self.dataset)
elif screenshot:
self.dataset = LoadScreenshots(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
elif from_img:
self.dataset = LoadPilAndNumpy(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None))
else:
self.dataset = LoadImages(source,
imgsz=imgsz,
stride=stride,
auto=pt,
transforms=getattr(self.model.model, 'transforms', None),
vid_stride=self.args.vid_stride)
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
self.webcam = webcam
self.screenshot = screenshot
self.from_img = from_img
self.imgsz = imgsz
self.bs = bs
@smart_inference_mode() @smart_inference_mode()
def __call__(self, source=None, model=None, stream=False): def __call__(self, source=None, model=None, stream=False):
if stream: if stream:
@ -163,14 +119,29 @@ class BasePredictor:
for _ in gen: # running CLI inference without accumulating any outputs (do not modify) for _ in gen: # running CLI inference without accumulating any outputs (do not modify)
pass pass
def setup_source(self, source):
if not self.model:
raise Exception("Model not initialized!")
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.dataset = load_inference_source(source=source,
transforms=getattr(self.model.model, 'transforms', None),
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
stride=self.model.stride,
auto=self.model.pt)
self.source_type = self.dataset.source_type
self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs
def stream_inference(self, source=None, model=None): def stream_inference(self, source=None, model=None):
self.run_callbacks("on_predict_start") self.run_callbacks("on_predict_start")
# setup model # setup model
if not self.model: if not self.model:
self.setup_model(model) self.setup_model(model)
# setup source. Run every time predict is called # setup source every time predict is called
self.setup_source(source) self.setup_source(source if source is not None else self.args.source)
# check if save_dir/ label file exists # check if save_dir/ label file exists
if self.args.save or self.args.save_txt: if self.args.save or self.args.save_txt:
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@ -198,7 +169,7 @@ class BasePredictor:
with self.dt[2]: with self.dt[2]:
self.results = self.postprocess(preds, im, im0s, self.classes) self.results = self.postprocess(preds, im, im0s, self.classes)
for i in range(len(im)): for i in range(len(im)):
p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s) p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s)
p = Path(p) p = Path(p)
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
@ -237,21 +208,6 @@ class BasePredictor:
self.device = device self.device = device
self.model.eval() self.model.eval()
def check_source(self, source):
source = source if source is not None else self.args.source
webcam, screenshot, from_img = False, False, False
if isinstance(source, (str, int, Path)): # int for local usb carame
source = str(source)
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
if is_url and is_file:
source = check_file(source) # download
else:
from_img = True
return source, webcam, screenshot, from_img
def show(self, p): def show(self, p):
im0 = self.annotator.result() im0 = self.annotator.result()
if platform.system() == 'Linux' and p not in self.windows: if platform.system() == 'Linux' and p not in self.windows:

@ -33,7 +33,7 @@ class ClassificationPredictor(BasePredictor):
im = im[None] # expand for batch dim im = im[None] # expand for batch dim
self.seen += 1 self.seen += 1
im0 = im0.copy() im0 = im0.copy()
if self.webcam or self.from_img: # batch_size >= 1 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: ' log_string += f'{idx}: '
frame = self.dataset.count frame = self.dataset.count
else: else:

@ -42,7 +42,7 @@ class DetectionPredictor(BasePredictor):
im = im[None] # expand for batch dim im = im[None] # expand for batch dim
self.seen += 1 self.seen += 1
im0 = im0.copy() im0 = im0.copy()
if self.webcam or self.from_img: # batch_size >= 1 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: ' log_string += f'{idx}: '
frame = self.dataset.count frame = self.dataset.count
else: else:

@ -43,7 +43,7 @@ class SegmentationPredictor(DetectionPredictor):
if len(im.shape) == 3: if len(im.shape) == 3:
im = im[None] # expand for batch dim im = im[None] # expand for batch dim
self.seen += 1 self.seen += 1
if self.webcam or self.from_img: # batch_size >= 1 if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1
log_string += f'{idx}: ' log_string += f'{idx}: '
frame = self.dataset.count frame = self.dataset.count
else: else:

Loading…
Cancel
Save