Support prediction of list of sources, in-memory dataset and other improvements (#685)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Ayush Chaurasia
2023-01-29 05:44:26 +05:30
committed by GitHub
parent a5410ed79e
commit 0609561549
9 changed files with 174 additions and 73 deletions

View File

@ -4,14 +4,16 @@ import glob
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from urllib.parse import urlparse
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from PIL import Image, ImageOps
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
@ -19,6 +21,13 @@ from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
from ultralytics.yolo.utils.checks import check_requirements
@dataclass
class SourceTypes:
webcam: bool = False
screenshot: bool = False
from_img: bool = False
class LoadStreams:
# YOLOv8 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
@ -63,6 +72,8 @@ class LoadStreams:
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
self.auto = auto and self.rect
self.transforms = transforms # optional
self.bs = self.__len__()
if not self.rect:
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
@ -128,6 +139,7 @@ class LoadScreenshots:
self.mode = 'stream'
self.frame = 0
self.sct = mss.mss()
self.bs = 1
# Parse monitor shape
monitor = self.sct.monitors[self.screen]
@ -185,6 +197,7 @@ class LoadImages:
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
self.bs = 1
if any(videos):
self.orientation = None # rotation degrees
self._new_video(videos[0]) # new video
@ -276,6 +289,7 @@ class LoadPilAndNumpy:
self.mode = 'image'
# generate fake paths
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
self.bs = 1
@staticmethod
def _single_check(im):
@ -311,6 +325,25 @@ class LoadPilAndNumpy:
return self
def autocast_list(source):
"""
Merges a list of source of different types into a list of numpy arrays or PIL images
"""
files = []
for _, im in enumerate(source):
if isinstance(im, (str, Path)): # filename or uri
files.append(Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im))
elif isinstance(im, (Image.Image, np.ndarray)): # PIL or np Image
files.append(im)
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
return files
LOADERS = [LoadStreams, LoadPilAndNumpy, LoadImages, LoadScreenshots]
if __name__ == "__main__":
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
dataset = LoadPilAndNumpy(im0=img)