diff --git a/tests/test_python.py b/tests/test_python.py index df5bb3a..3ecf340 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -3,10 +3,12 @@ from pathlib import Path import cv2 +import numpy as np import torch from PIL import Image from ultralytics import YOLO +from ultralytics.yolo.data.build import load_inference_source from ultralytics.yolo.utils import ROOT, SETTINGS MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt' @@ -40,6 +42,7 @@ def test_predict_dir(): def test_predict_img(): + model = YOLO(MODEL) img = Image.open(str(SOURCE)) output = model(source=img, save=True, verbose=True) # PIL @@ -54,6 +57,16 @@ def test_predict_img(): tens = torch.zeros(320, 640, 3) output = model(tens.numpy()) assert len(output) == 1, "predict test failed" + # test multiple source + imgs = [ + SOURCE, # filename + Path(SOURCE), # Path + 'https://ultralytics.com/images/zidane.jpg', # URI + cv2.imread(str(SOURCE)), # OpenCV + Image.open(SOURCE), # PIL + np.zeros((320, 640, 3))] # numpy + output = model(imgs) + assert len(output) == 6, "predict test failed!" 
def test_val(): @@ -129,3 +142,28 @@ def test_workflow(): model.val() model.predict(SOURCE) model.export(format="onnx", opset=12) # export a model to ONNX format + + +def test_predict_callback_and_setup(): + + def on_predict_batch_end(predictor): + # results -> List[batch_size] + path, _, im0s, _, _ = predictor.batch + # print('on_predict_batch_end', im0s[0].shape) + bs = [predictor.bs for i in range(0, len(path))] + predictor.results = zip(predictor.results, im0s, bs) + + model = YOLO("yolov8n.pt") + model.add_callback("on_predict_batch_end", on_predict_batch_end) + + dataset = load_inference_source(source=SOURCE, transforms=model.transforms) + bs = dataset.bs # access predictor properties + results = model.predict(dataset, stream=True) # source already setup + for _, (result, im0, bs) in enumerate(results): + print('test_callback', im0.shape) + print('test_callback', bs) + boxes = result.boxes # Boxes object for bbox outputs + print(boxes) + + +test_predict_img() diff --git a/ultralytics/yolo/data/__init__.py b/ultralytics/yolo/data/__init__.py index ebf4293..47c7e69 100644 --- a/ultralytics/yolo/data/__init__.py +++ b/ultralytics/yolo/data/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license from .base import BaseDataset -from .build import build_classification_dataloader, build_dataloader +from .build import build_classification_dataloader, build_dataloader, load_inference_source from .dataset import ClassificationDataset, SemanticDataset, YOLODataset from .dataset_wrappers import MixAndRectDataset diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index 8dee0b3..762f964 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -2,11 +2,18 @@ import os import random +from pathlib import Path import numpy as np import torch +from PIL import Image from torch.utils.data import DataLoader, dataloader, distributed +from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, 
LoadPilAndNumpy, LoadScreenshots, + LoadStreams, SourceTypes, autocast_list) +from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS +from ultralytics.yolo.utils.checks import check_file + from ..utils import LOGGER, colorstr from ..utils.torch_utils import torch_distributed_zero_first from .dataset import ClassificationDataset, YOLODataset @@ -123,3 +130,63 @@ def build_classification_dataloader(path, pin_memory=PIN_MEMORY, worker_init_fn=seed_worker, generator=generator) # or DataLoader(persistent_workers=True) + + +def check_source(source): + webcam, screenshot, from_img, in_memory = False, False, False, False + if isinstance(source, (str, int, Path)): # int for local usb camera + source = str(source) + is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) + is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')) + webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file) + screenshot = source.lower().startswith('screen') + if is_url and is_file: + source = check_file(source) # download + elif isinstance(source, tuple(LOADERS)): + in_memory = True + elif isinstance(source, (list, tuple)): + source = autocast_list(source) # convert all list elements to PIL or np arrays + from_img = True + elif isinstance(source, ((Image.Image, np.ndarray))): + from_img = True + else: + raise Exception( + "Unsupported type encountered! 
See docs for supported types https://docs.ultralytics.com/predict") + + return source, webcam, screenshot, from_img, in_memory + + +def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True): + """ + TODO: docs + """ + # source + source, webcam, screenshot, from_img, in_memory = check_source(source) + source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type + + # Dataloader + if in_memory: + dataset = source + elif webcam: + dataset = LoadStreams(source, + imgsz=imgsz, + stride=stride, + auto=auto, + transforms=transforms, + vid_stride=vid_stride) + + elif screenshot: + dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) + elif from_img: + dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms) + else: + dataset = LoadImages(source, + imgsz=imgsz, + stride=stride, + auto=auto, + transforms=transforms, + vid_stride=vid_stride) + + setattr(dataset, 'source_type', source_type) # attach source types + + return dataset diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 5338eef..0b01744 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -4,14 +4,16 @@ import glob import math import os import time +from dataclasses import dataclass from pathlib import Path from threading import Thread from urllib.parse import urlparse import cv2 import numpy as np +import requests import torch -from PIL import Image +from PIL import Image, ImageOps from ultralytics.yolo.data.augment import LetterBox from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS @@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops from ultralytics.yolo.utils.checks import check_requirements +@dataclass +class SourceTypes: + webcam: bool = False + 
screenshot: bool = False + from_img: bool = False + + class LoadStreams: # YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): @@ -63,6 +72,8 @@ class LoadStreams: self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal self.auto = auto and self.rect self.transforms = transforms # optional + self.bs = self.__len__() + if not self.rect: LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.') @@ -128,6 +139,7 @@ class LoadScreenshots: self.mode = 'stream' self.frame = 0 self.sct = mss.mss() + self.bs = 1 # Parse monitor shape monitor = self.sct.monitors[self.screen] @@ -185,6 +197,7 @@ class LoadImages: self.auto = auto self.transforms = transforms # optional self.vid_stride = vid_stride # video frame-rate stride + self.bs = 1 if any(videos): self.orientation = None # rotation degrees self._new_video(videos[0]) # new video @@ -276,6 +289,7 @@ class LoadPilAndNumpy: self.mode = 'image' # generate fake paths self.paths = [f"image{i}.jpg" for i in range(len(self.im0))] + self.bs = 1 @staticmethod def _single_check(im): @@ -311,6 +325,25 @@ class LoadPilAndNumpy: return self +def autocast_list(source): + """ + Merges a list of source of different types into a list of numpy arrays or PIL images + """ + files = [] + for _, im in enumerate(source): + if isinstance(im, (str, Path)): # filename or uri + files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im)) + elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image + files.append(im) + else: + raise Exception( + "Unsupported type encountered! 
See docs for supported types https://docs.ultralytics.com/predict") + + return files + + +LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots] + if __name__ == "__main__": img = cv2.imread(str(ROOT / "assets/bus.jpg")) dataset = LoadPilAndNumpy(im0=img) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 67cccaf..ad876ba 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -233,6 +233,13 @@ class YOLO: """ return self.model.names + @property + def transforms(self): + """ + Returns transform of the loaded model. + """ + return self.model.transforms if hasattr(self.model, 'transforms') else None + @staticmethod def add_callback(event: str, func): """ diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index b83fd2a..e93578b 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -30,13 +30,13 @@ from collections import defaultdict from pathlib import Path import cv2 +import torch from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.cfg import get_cfg -from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams -from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS +from ultralytics.yolo.data import load_inference_source from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, SETTINGS, callbacks, colorstr, ops -from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow +from ultralytics.yolo.utils.checks import check_imgsz, check_imshow from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode @@ -76,6 +76,8 @@ class BasePredictor: if self.args.conf is None: self.args.conf = 0.25 # default conf=0.25 self.done_warmup = False + if self.args.show: + self.args.show = check_imshow(warn=True) # Usable if setup is done 
self.model = None @@ -88,6 +90,7 @@ class BasePredictor: self.vid_path, self.vid_writer = None, None self.annotator = None self.data_path = None + self.source_type = None self.callbacks = defaultdict(list, callbacks.default_callbacks) # add callbacks callbacks.add_integration_callbacks(self) @@ -103,53 +106,6 @@ class BasePredictor: def postprocess(self, preds, img, orig_img, classes=None): return preds - def setup_source(self, source=None): - if not self.model: - raise Exception("setup model before setting up source!") - # source - source, webcam, screenshot, from_img = self.check_source(source) - # model - stride, pt = self.model.stride, self.model.pt - imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2) # check image size - - # Dataloader - bs = 1 # batch_size - if webcam: - self.args.show = check_imshow(warn=True) - self.dataset = LoadStreams(source, - imgsz=imgsz, - stride=stride, - auto=pt, - transforms=getattr(self.model.model, 'transforms', None), - vid_stride=self.args.vid_stride) - bs = len(self.dataset) - elif screenshot: - self.dataset = LoadScreenshots(source, - imgsz=imgsz, - stride=stride, - auto=pt, - transforms=getattr(self.model.model, 'transforms', None)) - elif from_img: - self.dataset = LoadPilAndNumpy(source, - imgsz=imgsz, - stride=stride, - auto=pt, - transforms=getattr(self.model.model, 'transforms', None)) - else: - self.dataset = LoadImages(source, - imgsz=imgsz, - stride=stride, - auto=pt, - transforms=getattr(self.model.model, 'transforms', None), - vid_stride=self.args.vid_stride) - self.vid_path, self.vid_writer = [None] * bs, [None] * bs - - self.webcam = webcam - self.screenshot = screenshot - self.from_img = from_img - self.imgsz = imgsz - self.bs = bs - @smart_inference_mode() def __call__(self, source=None, model=None, stream=False): if stream: @@ -163,14 +119,29 @@ class BasePredictor: for _ in gen: # running CLI inference without accumulating any outputs (do not modify) pass + def setup_source(self, source): + if not 
self.model: + raise Exception("Model not initialized!") + + self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size + self.dataset = load_inference_source(source=source, + transforms=getattr(self.model.model, 'transforms', None), + imgsz=self.imgsz, + vid_stride=self.args.vid_stride, + stride=self.model.stride, + auto=self.model.pt) + self.source_type = self.dataset.source_type + self.vid_path, self.vid_writer = [None] * self.dataset.bs, [None] * self.dataset.bs + def stream_inference(self, source=None, model=None): self.run_callbacks("on_predict_start") # setup model if not self.model: self.setup_model(model) - # setup source. Run every time predict is called - self.setup_source(source) + # setup source every time predict is called + self.setup_source(source if source is not None else self.args.source) + # check if save_dir/ label file exists if self.args.save or self.args.save_txt: (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) @@ -198,7 +169,7 @@ class BasePredictor: with self.dt[2]: self.results = self.postprocess(preds, im, im0s, self.classes) for i in range(len(im)): - p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s) + p, im0 = (path[i], im0s[i]) if self.source_type.webcam or self.source_type.from_img else (path, im0s) p = Path(p) if self.args.verbose or self.args.save or self.args.save_txt or self.args.show: @@ -237,21 +208,6 @@ class BasePredictor: self.device = device self.model.eval() - def check_source(self, source): - source = source if source is not None else self.args.source - webcam, screenshot, from_img = False, False, False - if isinstance(source, (str, int, Path)): # int for local usb carame - source = str(source) - is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) - is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://')) - webcam = source.isnumeric() or source.endswith('.streams') 
or (is_url and not is_file) - screenshot = source.lower().startswith('screen') - if is_url and is_file: - source = check_file(source) # download - else: - from_img = True - return source, webcam, screenshot, from_img - def show(self, p): im0 = self.annotator.result() if platform.system() == 'Linux' and p not in self.windows: diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index d50bdba..5c9a086 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -33,7 +33,7 @@ class ClassificationPredictor(BasePredictor): im = im[None] # expand for batch dim self.seen += 1 im0 = im0.copy() - if self.webcam or self.from_img: # batch_size >= 1 + if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count else: diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index 802e69b..5e3b4c8 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -42,7 +42,7 @@ class DetectionPredictor(BasePredictor): im = im[None] # expand for batch dim self.seen += 1 im0 = im0.copy() - if self.webcam or self.from_img: # batch_size >= 1 + if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count else: diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index fd52a18..49302c6 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -43,7 +43,7 @@ class SegmentationPredictor(DetectionPredictor): if len(im.shape) == 3: im = im[None] # expand for batch dim self.seen += 1 - if self.webcam or self.from_img: # batch_size >= 1 + if self.source_type.webcam or self.source_type.from_img: # batch_size >= 1 log_string += f'{idx}: ' frame = self.dataset.count else: