New YOLOv8 Results()
class for prediction outputs (#314)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Viet Nhat Thai <60825385+vietnhatthai@users.noreply.github.com> Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
This commit is contained in:
@ -11,10 +11,11 @@ from urllib.parse import urlparse
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
|
||||
@ -36,7 +37,7 @@ class LoadStreams:
|
||||
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
import pafy
|
||||
import pafy # noqa
|
||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0:
|
||||
@ -109,7 +110,7 @@ class LoadScreenshots:
|
||||
def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
|
||||
# source = [screen_number left top width height] (pixels)
|
||||
check_requirements('mss')
|
||||
import mss
|
||||
import mss # noqa
|
||||
|
||||
source, *params = source.split()
|
||||
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
|
||||
@ -254,3 +255,58 @@ class LoadImages:
|
||||
|
||||
def __len__(self):
|
||||
return self.nf # number of files
|
||||
|
||||
|
||||
class LoadPilAndNumpy:
|
||||
|
||||
def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None):
|
||||
if not isinstance(im0, list):
|
||||
im0 = [im0]
|
||||
self.im0 = [self._single_check(im) for im in im0]
|
||||
self.imgsz = imgsz
|
||||
self.stride = stride
|
||||
self.auto = auto
|
||||
self.transforms = transforms
|
||||
self.mode = 'image'
|
||||
# generate fake paths
|
||||
self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
|
||||
|
||||
@staticmethod
|
||||
def _single_check(im):
|
||||
assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
|
||||
if isinstance(im, Image.Image):
|
||||
im = np.asarray(im)[:, :, ::-1]
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im
|
||||
|
||||
def _single_preprocess(self, im, auto):
|
||||
if self.transforms:
|
||||
im = self.transforms(im) # transforms
|
||||
else:
|
||||
im = LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=im)
|
||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
return len(self.im0)
|
||||
|
||||
def __next__(self):
|
||||
if self.count == 1: # loop only once as it's batch inference
|
||||
raise StopIteration
|
||||
auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto
|
||||
im = [self._single_preprocess(im, auto) for im in self.im0]
|
||||
im = np.stack(im, 0) if len(im) > 1 else im[0][None]
|
||||
self.count += 1
|
||||
return self.paths, im, self.im0, None, ''
|
||||
|
||||
def __iter__(self):
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
img = cv2.imread(str(ROOT / "assets/bus.jpg"))
|
||||
dataset = LoadPilAndNumpy(im0=img)
|
||||
for d in dataset:
|
||||
print(d[0])
|
||||
|
Reference in New Issue
Block a user