Predictor support (#65)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>single_channel
parent
479992093c
commit
e6737f1207
After Width: | Height: | Size: 476 KiB |
After Width: | Height: | Size: 165 KiB |
@ -0,0 +1,254 @@
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data.augment import LetterBox
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops
|
||||
from ultralytics.yolo.utils.checks import check_requirements
|
||||
|
||||
|
||||
class LoadStreams:
|
||||
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
|
||||
def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.mode = 'stream'
|
||||
self.img_size = img_size
|
||||
self.stride = stride
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
|
||||
n = len(sources)
|
||||
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
|
||||
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
|
||||
for i, s in enumerate(sources): # index, source
|
||||
# Start thread to read frames from video stream
|
||||
st = f'{i + 1}/{n}: {s}... '
|
||||
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
|
||||
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
|
||||
import pafy
|
||||
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
if s == 0:
|
||||
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
|
||||
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
|
||||
cap = cv2.VideoCapture(s)
|
||||
assert cap.isOpened(), f'{st}Failed to open {s}'
|
||||
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
|
||||
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
|
||||
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
|
||||
|
||||
_, self.imgs[i] = cap.read() # guarantee first frame
|
||||
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
|
||||
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
|
||||
self.threads[i].start()
|
||||
LOGGER.info('') # newline
|
||||
|
||||
# check for common shapes
|
||||
s = np.stack([LetterBox(img_size, auto, stride=stride)(image=x).shape for x in self.imgs])
|
||||
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
|
||||
self.auto = auto and self.rect
|
||||
self.transforms = transforms # optional
|
||||
if not self.rect:
|
||||
LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
|
||||
|
||||
def update(self, i, cap, stream):
|
||||
# Read stream `i` frames in daemon thread
|
||||
n, f = 0, self.frames[i] # frame number, frame array
|
||||
while cap.isOpened() and n < f:
|
||||
n += 1
|
||||
cap.grab() # .read() = .grab() followed by .retrieve()
|
||||
if n % self.vid_stride == 0:
|
||||
success, im = cap.retrieve()
|
||||
if success:
|
||||
self.imgs[i] = im
|
||||
else:
|
||||
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
|
||||
self.imgs[i] = np.zeros_like(self.imgs[i])
|
||||
cap.open(stream) # re-open stream if signal was lost
|
||||
time.sleep(0.0) # wait time
|
||||
|
||||
def __iter__(self):
|
||||
self.count = -1
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
self.count += 1
|
||||
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
|
||||
cv2.destroyAllWindows()
|
||||
raise StopIteration
|
||||
|
||||
im0 = self.imgs.copy()
|
||||
if self.transforms:
|
||||
im = np.stack([self.transforms(x) for x in im0]) # transforms
|
||||
else:
|
||||
im = np.stack([LetterBox(self.img_size, self.auto, stride=self.stride)(image=x) for x in im0])
|
||||
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
|
||||
return self.sources, im, im0, None, ''
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
|
||||
|
||||
|
||||
class LoadScreenshots:
|
||||
# YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
|
||||
def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
|
||||
# source = [screen_number left top width height] (pixels)
|
||||
check_requirements('mss')
|
||||
import mss
|
||||
|
||||
source, *params = source.split()
|
||||
self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
|
||||
if len(params) == 1:
|
||||
self.screen = int(params[0])
|
||||
elif len(params) == 4:
|
||||
left, top, width, height = (int(x) for x in params)
|
||||
elif len(params) == 5:
|
||||
self.screen, left, top, width, height = (int(x) for x in params)
|
||||
self.img_size = img_size
|
||||
self.stride = stride
|
||||
self.transforms = transforms
|
||||
self.auto = auto
|
||||
self.mode = 'stream'
|
||||
self.frame = 0
|
||||
self.sct = mss.mss()
|
||||
|
||||
# Parse monitor shape
|
||||
monitor = self.sct.monitors[self.screen]
|
||||
self.top = monitor["top"] if top is None else (monitor["top"] + top)
|
||||
self.left = monitor["left"] if left is None else (monitor["left"] + left)
|
||||
self.width = width or monitor["width"]
|
||||
self.height = height or monitor["height"]
|
||||
self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
# mss screen capture: get raw pixels from the screen as np array
|
||||
im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
|
||||
s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(im0) # transforms
|
||||
else:
|
||||
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
|
||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
self.frame += 1
|
||||
return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
|
||||
|
||||
|
||||
class LoadImages:
|
||||
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
|
||||
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
|
||||
if isinstance(path, str) and Path(path).suffix == ".txt": # *.txt file with img/vid/dir on each line
|
||||
path = Path(path).read_text().rsplit()
|
||||
files = []
|
||||
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
|
||||
p = str(Path(p).resolve())
|
||||
if '*' in p:
|
||||
files.extend(sorted(glob.glob(p, recursive=True))) # glob
|
||||
elif os.path.isdir(p):
|
||||
files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
|
||||
elif os.path.isfile(p):
|
||||
files.append(p) # files
|
||||
else:
|
||||
raise FileNotFoundError(f'{p} does not exist')
|
||||
|
||||
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
|
||||
ni, nv = len(images), len(videos)
|
||||
|
||||
self.img_size = img_size
|
||||
self.stride = stride
|
||||
self.files = images + videos
|
||||
self.nf = ni + nv # number of files
|
||||
self.video_flag = [False] * ni + [True] * nv
|
||||
self.mode = 'image'
|
||||
self.auto = auto
|
||||
self.transforms = transforms # optional
|
||||
self.vid_stride = vid_stride # video frame-rate stride
|
||||
if any(videos):
|
||||
self._new_video(videos[0]) # new video
|
||||
else:
|
||||
self.cap = None
|
||||
assert self.nf > 0, f'No images or videos found in {p}. ' \
|
||||
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
|
||||
|
||||
def __iter__(self):
|
||||
self.count = 0
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self.count == self.nf:
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
|
||||
if self.video_flag[self.count]:
|
||||
# Read video
|
||||
self.mode = 'video'
|
||||
for _ in range(self.vid_stride):
|
||||
self.cap.grab()
|
||||
ret_val, im0 = self.cap.retrieve()
|
||||
while not ret_val:
|
||||
self.count += 1
|
||||
self.cap.release()
|
||||
if self.count == self.nf: # last video
|
||||
raise StopIteration
|
||||
path = self.files[self.count]
|
||||
self._new_video(path)
|
||||
ret_val, im0 = self.cap.read()
|
||||
|
||||
self.frame += 1
|
||||
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
|
||||
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
|
||||
|
||||
else:
|
||||
# Read image
|
||||
self.count += 1
|
||||
im0 = cv2.imread(path) # BGR
|
||||
assert im0 is not None, f'Image Not Found {path}'
|
||||
s = f'image {self.count}/{self.nf} {path}: '
|
||||
|
||||
if self.transforms:
|
||||
im = self.transforms(im0) # transforms
|
||||
else:
|
||||
im = LetterBox(self.img_size, self.auto, stride=self.stride)(image=im0)
|
||||
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
|
||||
return path, im, im0, self.cap, s
|
||||
|
||||
def _new_video(self, path):
|
||||
# Create a new video capture object
|
||||
self.frame = 0
|
||||
self.cap = cv2.VideoCapture(path)
|
||||
self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
|
||||
self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
|
||||
# self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
|
||||
|
||||
def _cv2_rotate(self, im):
|
||||
# Rotate a cv2 video manually
|
||||
if self.orientation == 0:
|
||||
return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
|
||||
elif self.orientation == 180:
|
||||
return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||
elif self.orientation == 90:
|
||||
return cv2.rotate(im, cv2.ROTATE_180)
|
||||
return im
|
||||
|
||||
def __len__(self):
|
||||
return self.nf # number of files
|
@ -0,0 +1,201 @@
|
||||
# predictor engine by Ultralytics
|
||||
"""
|
||||
Run prection on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
||||
Usage - sources:
|
||||
$ yolo task=... mode=predict model=s.pt --source 0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/Zgi9g1ksQHc' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
|
||||
Usage - formats:
|
||||
$ yolo task=... mode=predict --weights yolov5s.pt # PyTorch
|
||||
yolov5s.torchscript # TorchScript
|
||||
yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
|
||||
yolov5s_openvino_model # OpenVINO
|
||||
yolov5s.engine # TensorRT
|
||||
yolov5s.mlmodel # CoreML (macOS-only)
|
||||
yolov5s_saved_model # TensorFlow SavedModel
|
||||
yolov5s.pb # TensorFlow GraphDef
|
||||
yolov5s.tflite # TensorFlow Lite
|
||||
yolov5s_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolov5s_paddle_model # PaddlePaddle
|
||||
"""
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS, check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT, colorstr, ops
|
||||
from ultralytics.yolo.utils.checks import check_file, check_imshow
|
||||
from ultralytics.yolo.utils.configs import get_config
|
||||
from ultralytics.yolo.utils.files import increment_path
|
||||
from ultralytics.yolo.utils.modeling.autobackend import AutoBackend
|
||||
from ultralytics.yolo.utils.plotting import Annotator
|
||||
from ultralytics.yolo.utils.torch_utils import check_img_size, select_device, smart_inference_mode
|
||||
|
||||
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml"
|
||||
|
||||
|
||||
class BasePredictor:
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
self.args = get_config(config, overrides)
|
||||
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
|
||||
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.done_setup = False
|
||||
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
self.data = self.args.data # data_dict
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_path, self.vid_writer = None, None
|
||||
self.view_img = None
|
||||
self.annotator = None
|
||||
self.data_path = None
|
||||
|
||||
def preprocess(self, img):
|
||||
pass
|
||||
|
||||
def get_annotator(self, img):
|
||||
raise NotImplementedError("get_annotator function needs to be implemented")
|
||||
|
||||
def write_results(self, pred, batch, print_string):
|
||||
raise NotImplementedError("print_results function needs to be implemented")
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
return preds
|
||||
|
||||
def setup(self, source=None, model=None):
|
||||
# source
|
||||
source = str(source or self.args.source)
|
||||
self.save_img = not self.args.nosave and not source.endswith('.txt')
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
|
||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
|
||||
# data
|
||||
if self.data:
|
||||
if self.data.endswith(".yaml"):
|
||||
self.data = check_dataset_yaml(self.data)
|
||||
else:
|
||||
self.data = check_dataset(self.data)
|
||||
|
||||
# model
|
||||
device = select_device(self.args.device)
|
||||
model = model or self.args.model
|
||||
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
|
||||
model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half) # NOTE: not passing data
|
||||
stride, pt = model.stride, model.pt
|
||||
imgsz = check_img_size(self.args.img_size, s=stride) # check image size
|
||||
|
||||
# Dataloader
|
||||
bs = 1 # batch_size
|
||||
if webcam:
|
||||
self.view_img = check_imshow(warn=True)
|
||||
self.dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
|
||||
bs = len(self.dataset)
|
||||
elif screenshot:
|
||||
self.dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
|
||||
else:
|
||||
self.dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=self.args.vid_stride)
|
||||
self.vid_path, self.vid_writer = [None] * bs, [None] * bs
|
||||
model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
|
||||
|
||||
self.model = model
|
||||
self.webcam = webcam
|
||||
self.screenshot = screenshot
|
||||
self.imgsz = imgsz
|
||||
self.done_setup = True
|
||||
self.device = device
|
||||
|
||||
return model
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, source=None, model=None):
|
||||
if not self.done_setup:
|
||||
model = self.setup(source, model)
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
|
||||
for batch in self.dataset:
|
||||
path, im, im0s, vid_cap, s = batch
|
||||
visualize = increment_path(self.save_dir / Path(path).stem, mkdir=True) if self.args.visualize else False
|
||||
with self.dt[0]:
|
||||
im = self.preprocess(im)
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
|
||||
# Inference
|
||||
with self.dt[1]:
|
||||
preds = model(im, augment=self.args.augment, visualize=visualize)
|
||||
|
||||
# postprocess
|
||||
with self.dt[2]:
|
||||
preds = self.postprocess(preds, im, im0s)
|
||||
|
||||
for i in range(len(im)):
|
||||
if self.webcam:
|
||||
path, im0s = path[i], im0s[i]
|
||||
p = Path(path)
|
||||
s += self.write_results(i, preds, (p, im, im0s))
|
||||
|
||||
if self.args.view_img:
|
||||
self.show(p)
|
||||
|
||||
if self.save_img:
|
||||
self.save_preds(vid_cap, i, str(self.save_dir / p.name))
|
||||
|
||||
# Print time (inference-only)
|
||||
LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
|
||||
|
||||
# Print results
|
||||
t = tuple(x.t / self.seen * 1E3 for x in self.dt) # speeds per image
|
||||
LOGGER.info(
|
||||
f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
|
||||
% t)
|
||||
if self.args.save_txt or self.save_img:
|
||||
s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
|
||||
def show(self, p):
|
||||
im0 = self.annotator.result()
|
||||
if platform.system() == 'Linux' and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
|
||||
cv2.imshow(str(p), im0)
|
||||
cv2.waitKey(1) # 1 millisecond
|
||||
|
||||
def save_preds(self, vid_cap, idx, save_path):
|
||||
im0 = self.annotator.result()
|
||||
# save imgs
|
||||
if self.dataset.mode == 'image':
|
||||
cv2.imwrite(save_path, im0)
|
||||
else: # 'video' or 'stream'
|
||||
if self.vid_path[idx] != save_path: # new video
|
||||
self.vid_path[idx] = save_path
|
||||
if isinstance(self.vid_writer[idx], cv2.VideoWriter):
|
||||
self.vid_writer[idx].release() # release previous video writer
|
||||
if vid_cap: # video
|
||||
fps = vid_cap.get(cv2.CAP_PROP_FPS)
|
||||
w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
else: # stream
|
||||
fps, w, h = 30, im0.shape[1], im0.shape[0]
|
||||
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
|
||||
self.vid_writer[idx] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
self.vid_writer[idx].write(im0)
|
@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
|
||||
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
|
||||
"""
|
||||
Accepts yaml file name or DictConfig containing experiment configuration.
|
||||
Returns training args namespace
|
||||
:param config: Optional file name or DictConfig object
|
||||
"""
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
elif isinstance(config, Dict):
|
||||
config = OmegaConf.create(config)
|
||||
# override
|
||||
if isinstance(overrides, str):
|
||||
overrides = OmegaConf.load(overrides)
|
||||
elif isinstance(overrides, Dict):
|
||||
overrides = OmegaConf.create(overrides)
|
||||
|
||||
return OmegaConf.merge(config, overrides)
|
@ -0,0 +1,68 @@
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
|
||||
def get_annotator(self, img):
|
||||
return Annotator(img, example=str(self.model.names), pil=True)
|
||||
|
||||
def preprocess(self, img):
|
||||
img = torch.Tensor(img).to(self.model.device)
|
||||
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
return img
|
||||
|
||||
def write_results(self, idx, preds, batch):
|
||||
p, im, im0 = batch
|
||||
log_string = ""
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
im0 = im0.copy()
|
||||
if self.webcam: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.cound
|
||||
else:
|
||||
frame = getattr(self.dataset, 'frame', 0)
|
||||
|
||||
self.data_path = p
|
||||
# save_path = str(self.save_dir / p.name) # im.jpg
|
||||
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
self.annotator = self.get_annotator(im0)
|
||||
|
||||
prob = preds[idx]
|
||||
# Print results
|
||||
top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices
|
||||
log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
|
||||
|
||||
# write
|
||||
text = '\n'.join(f'{prob[j]:.2f} {self.model.names[j]}' for j in top5i)
|
||||
if self.save_img or self.args.view_img: # Add bbox to image
|
||||
self.annotator.text((32, 32), text, txt_color=(255, 255, 255))
|
||||
if self.args.save_txt: # Write to file
|
||||
with open(f'{self.txt_path}.txt', 'a') as f:
|
||||
f.write(text + '\n')
|
||||
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "squeezenet1_0"
|
||||
sz = cfg.img_size
|
||||
if type(sz) != int: # recieved listConfig
|
||||
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
|
||||
else:
|
||||
cfg.img_size = [sz, sz]
|
||||
predictor = ClassificationPredictor(cfg)
|
||||
predictor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
@ -1,2 +1,3 @@
|
||||
from ultralytics.yolo.v8.detect.predict import DetectionPredictor, predict
|
||||
from ultralytics.yolo.v8.detect.train import DetectionTrainer, train
|
||||
from ultralytics.yolo.v8.detect.val import DetectionValidator, val
|
||||
|
@ -0,0 +1,97 @@
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.predictor import BasePredictor
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
||||
|
||||
def get_annotator(self, img):
|
||||
return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))
|
||||
|
||||
def preprocess(self, img):
|
||||
img = torch.from_numpy(img).to(self.model.device)
|
||||
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
img /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return img
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf_thres,
|
||||
self.args.iou_thres,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det)
|
||||
|
||||
for i, pred in enumerate(preds):
|
||||
shape = orig_img[i].shape if self.webcam else orig_img.shape
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
|
||||
return preds
|
||||
|
||||
def write_results(self, idx, preds, batch):
|
||||
p, im, im0 = batch
|
||||
log_string = ""
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
im0 = im0.copy()
|
||||
if self.webcam: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
frame = getattr(self.dataset, 'frame', 0)
|
||||
|
||||
self.data_path = p
|
||||
# save_path = str(self.save_dir / p.name) # im.jpg
|
||||
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
self.annotator = self.get_annotator(im0)
|
||||
|
||||
det = preds[idx]
|
||||
if len(det) == 0:
|
||||
return log_string
|
||||
for c in det[:, 5].unique():
|
||||
n = (det[:, 5] == c).sum() # detections per class
|
||||
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "
|
||||
|
||||
# write
|
||||
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
|
||||
for *xyxy, conf, cls in reversed(det):
|
||||
if self.args.save_txt: # Write to file
|
||||
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
|
||||
line = (cls, *xywh, conf) if self.args.save_conf else (cls, *xywh) # label format
|
||||
with open(f'{self.txt_path}.txt', 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
if self.save_img or self.args.save_crop or self.args.view_img: # Add bbox to image
|
||||
c = int(cls) # integer class
|
||||
label = None if self.args.hide_labels else (
|
||||
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
|
||||
self.annotator.box_label(xyxy, label, color=colors(c, True))
|
||||
if self.args.save_crop:
|
||||
imc = im0.copy()
|
||||
save_one_box(xyxy,
|
||||
imc,
|
||||
file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',
|
||||
BGR=True)
|
||||
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "n.pt"
|
||||
sz = cfg.img_size
|
||||
if type(sz) != int: # recieved listConfig
|
||||
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
|
||||
else:
|
||||
cfg.img_size = [sz, sz]
|
||||
predictor = DetectionPredictor(cfg)
|
||||
predictor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
@ -1,2 +1,3 @@
|
||||
from ultralytics.yolo.v8.segment.predict import SegmentationPredictor, predict
|
||||
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train
|
||||
from ultralytics.yolo.v8.segment.val import SegmentationValidator, val
|
||||
|
@ -0,0 +1,115 @@
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import ROOT, ops
|
||||
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
||||
|
||||
from ..detect.predict import DetectionPredictor
|
||||
|
||||
|
||||
class SegmentationPredictor(DetectionPredictor):
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
masks = []
|
||||
if len(preds) == 2: # eval
|
||||
p, proto, = preds
|
||||
else: # len(3) train
|
||||
p, proto, _ = preds
|
||||
# TODO: filter by classes
|
||||
p = ops.non_max_suppression(p,
|
||||
self.args.conf_thres,
|
||||
self.args.iou_thres,
|
||||
agnostic=self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nm=32)
|
||||
for i, pred in enumerate(p):
|
||||
shape = orig_img[i].shape if self.webcam else orig_img.shape
|
||||
if not len(pred):
|
||||
continue
|
||||
if self.args.retina_masks:
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2])) # HWC
|
||||
else:
|
||||
masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)) # HWC
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
|
||||
|
||||
return (p, masks)
|
||||
|
||||
def write_results(self, idx, preds, batch):
|
||||
p, im, im0 = batch
|
||||
log_string = ""
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
self.seen += 1
|
||||
if self.webcam: # batch_size >= 1
|
||||
log_string += f'{idx}: '
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
frame = getattr(self.dataset, 'frame', 0)
|
||||
|
||||
self.data_path = p
|
||||
self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
self.annotator = self.get_annotator(im0)
|
||||
|
||||
preds, masks = preds
|
||||
det = preds[idx]
|
||||
if len(det) == 0:
|
||||
return log_string
|
||||
# Segments
|
||||
mask = masks[idx]
|
||||
if self.args.save_txt:
|
||||
segments = [
|
||||
ops.scale_segments(im0.shape if self.arg.retina_masks else im.shape[2:], x, im0.shape, normalize=True)
|
||||
for x in reversed(ops.masks2segments(mask))]
|
||||
|
||||
# Print results
|
||||
for c in det[:, 5].unique():
|
||||
n = (det[:, 5] == c).sum() # detections per class
|
||||
log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, " # add to string
|
||||
|
||||
# Mask plotting
|
||||
self.annotator.masks(
|
||||
mask,
|
||||
colors=[colors(x, True) for x in det[:, 5]],
|
||||
im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
|
||||
255 if self.args.retina_masks else im[idx])
|
||||
|
||||
# Write results
|
||||
for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
|
||||
if self.args.save_txt: # Write to file
|
||||
seg = segments[j].reshape(-1) # (n,2) to (n*2)
|
||||
line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg) # label format
|
||||
with open(f'{self.txt_path}.txt', 'a') as f:
|
||||
f.write(('%g ' * len(line)).rstrip() % line + '\n')
|
||||
|
||||
if self.save_img or self.args.save_crop or self.args.view_img:
|
||||
c = int(cls) # integer class
|
||||
label = None if self.args.hide_labels else (
|
||||
self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
|
||||
self.annotator.box_label(xyxy, label, color=colors(c, True))
|
||||
# annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
|
||||
if self.args.save_crop:
|
||||
imc = im0.copy()
|
||||
save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True)
|
||||
|
||||
return log_string
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def predict(cfg):
|
||||
cfg.model = cfg.model or "n.pt"
|
||||
sz = cfg.img_size
|
||||
if type(sz) != int: # recieved listConfig
|
||||
cfg.img_size = [sz[0], sz[0]] if len(cfg.img_size) == 1 else [sz[0], sz[1]] # expand
|
||||
else:
|
||||
cfg.img_size = [sz, sz]
|
||||
predictor = SegmentationPredictor(cfg)
|
||||
predictor()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
predict()
|
Loading…
Reference in new issue