Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-29 05:44:26 +05:30
committed by GitHub
parent a5410ed79e
commit 0609561549
9 changed files with 174 additions and 73 deletions

View File

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from .base import BaseDataset
from .build import build_classification_dataloader, build_dataloader
from .build import build_classification_dataloader, build_dataloader, load_inference_source
from .dataset import ClassificationDataset, SemanticDataset, YOLODataset
from .dataset_wrappers import MixAndRectDataset

View File

@ -2,11 +2,18 @@
import os
import random
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, dataloader, distributed
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
LoadStreams, SourceTypes, autocast_list)
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.yolo.utils.checks import check_file
from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
pin_memory=PIN_MEMORY,
worker_init_fn=seed_worker,
generator=generator) # or DataLoader(persistent_workers=True)
def check_source(source):
    """
    Classify an inference source and normalise it for the loader classes.

    Args:
        source: One of
            - ``str`` / ``int`` / ``Path``: converted to ``str`` and inspected (file name, URL,
              numeric camera index, ``.streams`` file or the ``screen`` keyword),
            - an instance of one of the classes in ``LOADERS`` (returned unchanged),
            - a ``list`` / ``tuple``: passed through ``autocast_list``,
            - a single ``PIL.Image.Image`` or ``numpy.ndarray``.

    Returns:
        tuple: ``(source, webcam, screenshot, from_img, in_memory)`` where ``source`` is the
        (possibly converted) input and the remaining items are booleans describing it.

    Raises:
        Exception: If ``source`` is of none of the supported types.
    """
    webcam = screenshot = from_img = in_memory = False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
        is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
        # Match the keyword 'screen' as a whole leading token (e.g. 'screen' or 'screen 0').
        # A bare prefix test would also flag ordinary paths such as 'screenshot.jpg' or
        # 'screens/' and wrongly route them to screen capture.
        screenshot = source.lower().split()[:1] == ['screen'] and not is_file
        if is_url and is_file:
            source = check_file(source)  # download
    elif isinstance(source, tuple(LOADERS)):
        # already a constructed loader object: hand it back untouched
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    else:
        raise Exception(
            "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")

    return source, webcam, screenshot, from_img, in_memory
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
    """
    Build the loader object that feeds a prediction run from ``source``.

    ``source`` is first classified by ``check_source``; the matching loader class is then
    instantiated and tagged with a ``source_type`` attribute.

    Args:
        source: Anything accepted by ``check_source``.
        transforms: Forwarded to the loader as ``transforms``.
        imgsz: Forwarded to the loader as ``imgsz``.
        vid_stride: Forwarded as ``vid_stride`` to ``LoadStreams`` and ``LoadImages`` only.
        stride: Forwarded to the loader as ``stride``.
        auto: Forwarded to the loader as ``auto``.

    Returns:
        The loader object (the original ``source`` itself when it already is a loader),
        carrying a ``source_type`` attribute.
    """
    source, webcam, screenshot, from_img, in_memory = check_source(source)

    if in_memory:
        # Pre-built loader: reuse it together with the source type it already carries.
        source_type = source.source_type
        dataset = source
    else:
        source_type = SourceTypes(webcam, screenshot, from_img)
        shared_kwargs = dict(imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
        if webcam:
            dataset = LoadStreams(source, vid_stride=vid_stride, **shared_kwargs)
        elif screenshot:
            dataset = LoadScreenshots(source, **shared_kwargs)
        elif from_img:
            dataset = LoadPilAndNumpy(source, **shared_kwargs)
        else:
            dataset = LoadImages(source, vid_stride=vid_stride, **shared_kwargs)

    # attach source types
    dataset.source_type = source_type
    return dataset

View File

@ -4,14 +4,16 @@ import glob
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from PIL import Image, ImageOps
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements
@dataclass
class SourceTypes:
    """Boolean flags recording which kind of inference source a loader was built from."""

    # source was classified as a webcam / stream
    webcam: bool = False
    # source was classified as a screen capture
    screenshot: bool = False
    # source was given as in-memory image data (PIL / numpy, or a list of them)
    from_img: bool = False
class LoadStreams:
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
@ -63,6 +72,8 @@ class LoadStreams:
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
self.bs = self.__len__()
if not self.rect:
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
@ -128,6 +139,7 @@ class LoadScreenshots:
self.mode = 'stream'
self.frame = 0
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
@ -185,6 +197,7 @@ class LoadImages:
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos):
self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
self.mode = 'image'
# generate fake paths
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
self.bs = 1
@staticmethod
def _single_check(im):
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
return self
def autocast_list(source):
    """
    Convert a sequence of mixed-type sources into a list of PIL images / numpy arrays.

    ``str`` and ``Path`` entries are opened with ``Image.open``; entries whose string form
    starts with ``'http'`` are fetched with ``requests`` first. PIL images and numpy arrays
    are kept as they are. Any other entry type raises an ``Exception``.
    """
    converted = []
    for item in source:
        if isinstance(item, (str, Path)):  # filename or uri
            is_remote = str(item).startswith('http')
            handle = requests.get(item, stream=True).raw if is_remote else item
            converted.append(Image.open(handle))
            continue
        if not isinstance(item, (Image.Image, np.ndarray)):
            raise Exception(
                "Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
        converted.append(item)  # PIL or np Image, used unchanged
    return converted
# Loader classes exported for isinstance checks; a source that is already an instance of
# one of these is treated as a ready-made (in-memory) dataset by the dataloader builder.
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
if __name__ == "__main__":
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
dataset = LoadPilAndNumpy(im0=img)