New YOLOv8 Results() class for prediction outputs (#314)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Viet Nhat Thai <60825385+vietnhatthai@users.noreply.github.com>
Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
@@ -11,10 +11,11 @@ from urllib.parse import urlparse

 import cv2
 import numpy as np
 import torch
+from PIL import Image

 from ultralytics.yolo.data.augment import LetterBox
 from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
-from ultralytics.yolo.utils import LOGGER, is_colab, is_kaggle, ops
+from ultralytics.yolo.utils import LOGGER, ROOT, is_colab, is_kaggle, ops
 from ultralytics.yolo.utils.checks import check_requirements

@@ -36,7 +37,7 @@ class LoadStreams:
             if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                 # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
                 check_requirements(('pafy', 'youtube_dl==2020.12.2'))
-                import pafy
+                import pafy  # noqa
                 s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
             s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
             if s == 0:
@@ -109,7 +110,7 @@ class LoadScreenshots:
     def __init__(self, source, imgsz=640, stride=32, auto=True, transforms=None):
         # source = [screen_number left top width height] (pixels)
         check_requirements('mss')
-        import mss
+        import mss  # noqa

         source, *params = source.split()
         self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
@@ -254,3 +255,58 @@ class LoadImages:

     def __len__(self):
         return self.nf  # number of files
+
+
+class LoadPilAndNumpy:
+
+    def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None):
+        if not isinstance(im0, list):
+            im0 = [im0]
+        self.im0 = [self._single_check(im) for im in im0]
+        self.imgsz = imgsz
+        self.stride = stride
+        self.auto = auto
+        self.transforms = transforms
+        self.mode = 'image'
+        # generate fake paths
+        self.paths = [f"image{i}.jpg" for i in range(len(self.im0))]
+
+    @staticmethod
+    def _single_check(im):
+        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
+        if isinstance(im, Image.Image):
+            im = np.asarray(im)[:, :, ::-1]
+            im = np.ascontiguousarray(im)  # contiguous
+        return im
+
+    def _single_preprocess(self, im, auto):
+        if self.transforms:
+            im = self.transforms(im)  # transforms
+        else:
+            im = LetterBox(self.imgsz, auto=auto, stride=self.stride)(image=im)
+            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+            im = np.ascontiguousarray(im)  # contiguous
+        return im
+
+    def __len__(self):
+        return len(self.im0)
+
+    def __next__(self):
+        if self.count == 1:  # loop only once as it's batch inference
+            raise StopIteration
+        auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto
+        im = [self._single_preprocess(im, auto) for im in self.im0]
+        im = np.stack(im, 0) if len(im) > 1 else im[0][None]
+        self.count += 1
+        return self.paths, im, self.im0, None, ''
+
+    def __iter__(self):
+        self.count = 0
+        return self
+
+
+if __name__ == "__main__":
+    img = cv2.imread(str(ROOT / "assets/bus.jpg"))
+    dataset = LoadPilAndNumpy(im0=img)
+    for d in dataset:
+        print(d[0])
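For context, a minimal sketch of driving the new loader with a PIL image and a NumPy array together (hypothetical values; assumes Pillow, numpy and this module are installed):

    import numpy as np
    from PIL import Image
    from ultralytics.yolo.data.dataloaders.stream_loaders import LoadPilAndNumpy

    pil_im = Image.new("RGB", (640, 480))            # PIL input, converted to a BGR ndarray by _single_check
    np_im = np.zeros((480, 640, 3), dtype=np.uint8)  # raw ndarray input, used as-is
    dataset = LoadPilAndNumpy(im0=[pil_im, np_im])
    for paths, im, im0, _, _ in dataset:             # iterates exactly once (batch inference)
        print(paths, im.shape)                       # fake paths and the stacked (2, 3, h, w) batch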
@@ -54,8 +54,8 @@ class YOLO:
         # Load or create new YOLO model
         {'.pt': self._load, '.yaml': self._new}[Path(model).suffix](model)

-    def __call__(self, source, **kwargs):
-        return self.predict(source, **kwargs)
+    def __call__(self, source=None, stream=False, verbose=False, **kwargs):
+        return self.predict(source, stream, verbose, **kwargs)

     def _new(self, cfg: str, verbose=True):
         """
@@ -111,13 +111,20 @@ class YOLO:
         self.model.fuse()

     @smart_inference_mode()
-    def predict(self, source, return_outputs=False, **kwargs):
+    def predict(self, source=None, stream=False, verbose=False, **kwargs):
         """
-        Visualize prediction.
+        Perform prediction using the YOLO model.

         Args:
-            source (str): Accepts all source types accepted by yolo
-            **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
+            source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
+                Accepts all source types accepted by the YOLO model.
+            stream (bool): Whether to stream the predictions or not. Defaults to False.
+            verbose (bool): Whether to print verbose information or not. Defaults to False.
+            **kwargs : Additional keyword arguments passed to the predictor.
+                Check the 'configuration' section in the documentation for all available options.
+
+        Returns:
+            (dict): The prediction results.
         """
         overrides = self.overrides.copy()
         overrides["conf"] = 0.25
@@ -127,8 +134,8 @@ class YOLO:
             predictor = self.PredictorClass(overrides=overrides)

         predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2)  # check image size
-        predictor.setup(model=self.model, source=source, return_outputs=return_outputs)
-        return predictor() if return_outputs else predictor.predict_cli()
+        predictor.setup(model=self.model, source=source)
+        return predictor(stream=stream, verbose=verbose)

     @smart_inference_mode()
     def val(self, data=None, **kwargs):
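Taken together, these model.py changes mean YOLO.predict now returns Results objects (merged into one list, or yielded lazily when stream=True) instead of a raw output dict. A hedged sketch of the new calling convention (assumes local yolov8n.pt weights and a sample image/video exist):

    from ultralytics import YOLO

    model = YOLO("yolov8n.pt")
    results = model.predict("assets/bus.jpg")  # list of Results, one per image
    for batch in model.predict("video.mp4", stream=True, verbose=True):
        pass                                   # each yielded item covers one batch of frames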
@@ -27,13 +27,14 @@ Usage - formats:
 """
 import platform
 from collections import defaultdict
+from itertools import chain
 from pathlib import Path

 import cv2

 from ultralytics.nn.autobackend import AutoBackend
 from ultralytics.yolo.configs import get_config
-from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams
+from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadPilAndNumpy, LoadScreenshots, LoadStreams
 from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
 from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops
 from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow
@@ -89,7 +90,6 @@ class BasePredictor:
         self.vid_path, self.vid_writer = None, None
         self.annotator = None
         self.data_path = None
-        self.output = {}
         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
         callbacks.add_integration_callbacks(self)

@@ -99,29 +99,18 @@ class BasePredictor:
     def get_annotator(self, img):
         raise NotImplementedError("get_annotator function needs to be implemented")

-    def write_results(self, pred, batch, print_string):
+    def write_results(self, results, batch, print_string):
         raise NotImplementedError("print_results function needs to be implemented")

     def postprocess(self, preds, img, orig_img):
         return preds

-    def setup(self, source=None, model=None, return_outputs=False):
+    def setup(self, source=None, model=None):
         # source
-        source = str(source if source is not None else self.args.source)
-        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
-        is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
-        webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
-        screenshot = source.lower().startswith('screen')
-        if is_url and is_file:
-            source = check_file(source)  # download
-
+        source, webcam, screenshot, from_img = self.check_source(source)
         # model
-        device = select_device(self.args.device)
-        model = model or self.args.model
-        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
-        model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
-        stride, pt = model.stride, model.pt
-        imgsz = check_imgsz(self.args.imgsz, stride=stride)  # check image size
+        stride, pt = self.setup_model(model)
+        imgsz = check_imgsz(self.args.imgsz, stride=stride, min_dim=2)  # check image size

         # Dataloader
         bs = 1  # batch_size
@@ -131,7 +120,7 @@ class BasePredictor:
                                        imgsz=imgsz,
                                        stride=stride,
                                        auto=pt,
-                                       transforms=getattr(model.model, 'transforms', None),
+                                       transforms=getattr(self.model.model, 'transforms', None),
                                        vid_stride=self.args.vid_stride)
             bs = len(self.dataset)
         elif screenshot:
@@ -139,32 +128,47 @@ class BasePredictor:
                                            imgsz=imgsz,
                                            stride=stride,
                                            auto=pt,
-                                           transforms=getattr(model.model, 'transforms', None))
+                                           transforms=getattr(self.model.model, 'transforms', None))
+        elif from_img:
+            self.dataset = LoadPilAndNumpy(source,
+                                           imgsz=imgsz,
+                                           stride=stride,
+                                           auto=pt,
+                                           transforms=getattr(self.model.model, 'transforms', None))
         else:
             self.dataset = LoadImages(source,
                                       imgsz=imgsz,
                                       stride=stride,
                                       auto=pt,
-                                      transforms=getattr(model.model, 'transforms', None),
+                                      transforms=getattr(self.model.model, 'transforms', None),
                                       vid_stride=self.args.vid_stride)
         self.vid_path, self.vid_writer = [None] * bs, [None] * bs
-        model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
+        self.model.warmup(imgsz=(1 if pt or self.model.triton else bs, 3, *imgsz))  # warmup

-        self.model = model
         self.webcam = webcam
         self.screenshot = screenshot
+        self.from_img = from_img
         self.imgsz = imgsz
         self.done_setup = True
-        self.device = device
-        self.return_outputs = return_outputs

         return model

     @smart_inference_mode()
-    def __call__(self, source=None, model=None, return_outputs=False):
+    def __call__(self, source=None, model=None, verbose=False, stream=False):
+        if stream:
+            return self.stream_inference(source, model, verbose)
+        else:
+            return list(chain(*list(self.stream_inference(source, model, verbose))))  # merge list of Result into one
+
+    def predict_cli(self):
+        # Method used for cli prediction. It uses always generator as outputs as not required by cli mode
+        gen = self.stream_inference(verbose=True)
+        for _ in gen:  # running CLI inference without accumulating any outputs (do not modify)
+            pass
+
+    def stream_inference(self, source=None, model=None, verbose=False):
         self.run_callbacks("on_predict_start")
-        model = self.model if self.done_setup else self.setup(source, model, return_outputs)
-        model.eval()
+        if not self.done_setup:
+            self.setup(source, model)
         self.seen, self.windows, self.dt = 0, [], (ops.Profile(), ops.Profile(), ops.Profile())
         for batch in self.dataset:
             self.run_callbacks("on_predict_batch_start")
@@ -177,17 +181,17 @@ class BasePredictor:

             # Inference
             with self.dt[1]:
-                preds = model(im, augment=self.args.augment, visualize=visualize)
+                preds = self.model(im, augment=self.args.augment, visualize=visualize)

             # postprocess
             with self.dt[2]:
-                preds = self.postprocess(preds, im, im0s)
-
+                results = self.postprocess(preds, im, im0s)
             for i in range(len(im)):
-                if self.webcam:
-                    path, im0s = path[i], im0s[i]
-                p = Path(path)
-                s += self.write_results(i, preds, (p, im, im0s))
+                p, im0 = (path[i], im0s[i]) if self.webcam or self.from_img else (path, im0s)
+                p = Path(p)
+
+                if verbose or self.args.save or self.args.save_txt:
+                    s += self.write_results(i, results, (p, im, im0))

                 if self.args.show:
                     self.show(p)
@@ -195,30 +199,50 @@ class BasePredictor:
                 if self.args.save:
                     self.save_preds(vid_cap, i, str(self.save_dir / p.name))

-            if self.return_outputs:
-                yield self.output
-                self.output.clear()
+            yield results

             # Print time (inference-only)
-            LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")
+            if verbose:
+                LOGGER.info(f"{s}{'' if len(preds) else '(no detections), '}{self.dt[1].dt * 1E3:.1f}ms")

             self.run_callbacks("on_predict_batch_end")

         # Print results
-        t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
-        LOGGER.info(
-            f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape {(1, 3, *self.imgsz)}'
-            % t)
+        if verbose:
+            t = tuple(x.t / self.seen * 1E3 for x in self.dt)  # speeds per image
+            LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms postprocess per image at shape '
+                        f'{(1, 3, *self.imgsz)}' % t)
         if self.args.save_txt or self.args.save:
-            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" if self.args.save_txt else ''
+            s = f"\n{len(list(self.save_dir.glob('labels/*.txt')))} labels saved to {self.save_dir / 'labels'}" \
+                if self.args.save_txt else ''
             LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")

         self.run_callbacks("on_predict_end")

-    def predict_cli(self, source=None, model=None, return_outputs=False):
-        # as __call__ is a generator now so have to treat it like a generator
-        for _ in (self.__call__(source, model, return_outputs)):
-            pass
+    def setup_model(self, model):
+        device = select_device(self.args.device)
+        model = model or self.args.model
+        self.args.half &= device.type != 'cpu'  # half precision only supported on CUDA
+        model = AutoBackend(model, device=device, dnn=self.args.dnn, fp16=self.args.half)
+        self.model = model
+        self.device = device
+        self.model.eval()
+        return model.stride, model.pt
+
+    def check_source(self, source):
+        source = source if source is not None else self.args.source
+        webcam, screenshot, from_img = False, False, False
+        if isinstance(source, (str, int, Path)):  # int for local usb camera
+            source = str(source)
+            is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+            is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+            webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
+            screenshot = source.lower().startswith('screen')
+            if is_url and is_file:
+                source = check_file(source)  # download
+        else:
+            from_img = True
+        return source, webcam, screenshot, from_img

     def show(self, p):
         im0 = self.annotator.result()
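The non-streaming branch of __call__ above flattens the per-batch lists of Results into one list with itertools.chain; a tiny standalone illustration of that merge:

    from itertools import chain

    batches = [["r0", "r1"], ["r2"], ["r3", "r4"]]  # stand-ins for per-batch Results lists
    merged = list(chain(*batches))
    assert merged == ["r0", "r1", "r2", "r3", "r4"]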
ultralytics/yolo/engine/results.py (new file, 284 lines)
@@ -0,0 +1,284 @@
+from functools import lru_cache
+
+import numpy as np
+import torch
+
+from ultralytics.yolo.utils import LOGGER, ops
+
+
+class Results:
+    """
+    A class for storing and manipulating inference results.
+
+    Args:
+        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
+        masks (Masks, optional): A Masks object containing the detection masks.
+        probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
+        orig_shape (tuple, optional): Original image size.
+
+    Attributes:
+        boxes (Boxes, optional): A Boxes object containing the detection bounding boxes.
+        masks (Masks, optional): A Masks object containing the detection masks.
+        probs (torch.Tensor, optional): A tensor containing the detection class probabilities.
+        orig_shape (tuple, optional): Original image size.
+    """
+
+    def __init__(self, boxes=None, masks=None, probs=None, orig_shape=None) -> None:
+        self.boxes = Boxes(boxes, orig_shape) if boxes is not None else None  # native size boxes
+        self.masks = Masks(masks, orig_shape) if masks is not None else None  # native size or imgsz masks
+        self.probs = probs.softmax(0) if probs is not None else None
+        self.orig_shape = orig_shape
+        self.comp = ["boxes", "masks", "probs"]
+
+    def pandas(self):
+        pass
+        # TODO masks.pandas + boxes.pandas + cls.pandas
+
+    def __getitem__(self, idx):
+        r = Results(orig_shape=self.orig_shape)
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            setattr(r, item, getattr(self, item)[idx])
+        return r
+
+    def cpu(self):
+        r = Results(orig_shape=self.orig_shape)
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            setattr(r, item, getattr(self, item).cpu())
+        return r
+
+    def numpy(self):
+        r = Results(orig_shape=self.orig_shape)
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            setattr(r, item, getattr(self, item).numpy())
+        return r
+
+    def cuda(self):
+        r = Results(orig_shape=self.orig_shape)
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            setattr(r, item, getattr(self, item).cuda())
+        return r
+
+    def to(self, *args, **kwargs):
+        r = Results(orig_shape=self.orig_shape)
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            setattr(r, item, getattr(self, item).to(*args, **kwargs))
+        return r
+
+    def __len__(self):
+        for item in self.comp:
+            if getattr(self, item) is None:
+                continue
+            return len(getattr(self, item))
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        s = f'Ultralytics YOLO {self.__class__} instance\n'  # string
+        if self.boxes:
+            s = s + self.boxes.__repr__() + '\n'
+        if self.masks:
+            s = s + self.masks.__repr__() + '\n'
+        if self.probs:
+            s = s + self.probs.__repr__()
+        s += f'original size: {self.orig_shape}\n'
+
+        return s
+
+
+class Boxes:
+    """
+    A class for storing and manipulating detection boxes.
+
+    Args:
+        boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
+            with shape (num_boxes, 6). The last two columns should contain confidence and class values.
+        orig_shape (tuple): Original image size, in the format (height, width).
+
+    Attributes:
+        boxes (torch.Tensor) or (numpy.ndarray): A tensor or numpy array containing the detection boxes,
+            with shape (num_boxes, 6).
+        orig_shape (torch.Tensor) or (numpy.ndarray): Original image size, in the format (height, width).
+
+    Properties:
+        xyxy (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format.
+        conf (torch.Tensor) or (numpy.ndarray): The confidence values of the boxes.
+        cls (torch.Tensor) or (numpy.ndarray): The class values of the boxes.
+        xywh (torch.Tensor) or (numpy.ndarray): The boxes in xywh format.
+        xyxyn (torch.Tensor) or (numpy.ndarray): The boxes in xyxy format normalized by original image size.
+        xywhn (torch.Tensor) or (numpy.ndarray): The boxes in xywh format normalized by original image size.
+    """
+
+    def __init__(self, boxes, orig_shape) -> None:
+        if boxes.ndim == 1:
+            boxes = boxes[None, :]
+        assert boxes.shape[-1] == 6  # xyxy, conf, cls
+        self.boxes = boxes
+        self.orig_shape = torch.as_tensor(orig_shape, device=boxes.device) if isinstance(boxes, torch.Tensor) \
+            else np.asarray(orig_shape)
+
+    @property
+    def xyxy(self):
+        return self.boxes[:, :4]
+
+    @property
+    def conf(self):
+        return self.boxes[:, -2]
+
+    @property
+    def cls(self):
+        return self.boxes[:, -1]
+
+    @property
+    @lru_cache(maxsize=2)  # maxsize 1 should suffice
+    def xywh(self):
+        return ops.xyxy2xywh(self.xyxy)
+
+    @property
+    @lru_cache(maxsize=2)
+    def xyxyn(self):
+        return self.xyxy / self.orig_shape[[1, 0, 1, 0]]
+
+    @property
+    @lru_cache(maxsize=2)
+    def xywhn(self):
+        return self.xywh / self.orig_shape[[1, 0, 1, 0]]
+
+    def cpu(self):
+        boxes = self.boxes.cpu()
+        return Boxes(boxes, self.orig_shape)
+
+    def numpy(self):
+        boxes = self.boxes.numpy()
+        return Boxes(boxes, self.orig_shape)
+
+    def cuda(self):
+        boxes = self.boxes.cuda()
+        return Boxes(boxes, self.orig_shape)
+
+    def to(self, *args, **kwargs):
+        boxes = self.boxes.to(*args, **kwargs)
+        return Boxes(boxes, self.orig_shape)
+
+    def pandas(self):
+        LOGGER.info('results.pandas() method not yet implemented')
+        '''
+        new = copy(self)  # return copy
+        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
+        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
+        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
+            a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)]  # update
+            setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
+        return new
+        '''
+
+    @property
+    def shape(self):
+        return self.boxes.shape
+
+    def __len__(self):  # override len(results)
+        return len(self.boxes)
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        return (f"Ultralytics YOLO {self.__class__} boxes\n" + f"type: {type(self.boxes)}\n" +
+                f"shape: {self.boxes.shape}\n" + f"dtype: {self.boxes.dtype}")
+
+    def __getitem__(self, idx):
+        boxes = self.boxes[idx]
+        return Boxes(boxes, self.orig_shape)
+
+
+class Masks:
+    """
+    A class for storing and manipulating detection masks.
+
+    Args:
+        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
+        orig_shape (tuple): Original image size, in the format (height, width).
+
+    Attributes:
+        masks (torch.Tensor): A tensor containing the detection masks, with shape (num_masks, height, width).
+        orig_shape (tuple): Original image size, in the format (height, width).
+
+    Properties:
+        segments (list): A list of segments which includes x,y,w,h,label,confidence, and mask of each detection masks.
+    """
+
+    def __init__(self, masks, orig_shape) -> None:
+        self.masks = masks  # N, h, w
+        self.orig_shape = orig_shape
+
+    @property
+    @lru_cache(maxsize=1)
+    def segments(self):
+        return [
+            ops.scale_segments(self.masks.shape[1:], x, self.orig_shape, normalize=True)
+            for x in reversed(ops.masks2segments(self.masks))]
+
+    @property
+    def shape(self):
+        return self.masks.shape
+
+    def cpu(self):
+        masks = self.masks.cpu()
+        return Masks(masks, self.orig_shape)
+
+    def numpy(self):
+        masks = self.masks.numpy()
+        return Masks(masks, self.orig_shape)
+
+    def cuda(self):
+        masks = self.masks.cuda()
+        return Masks(masks, self.orig_shape)
+
+    def to(self, *args, **kwargs):
+        masks = self.masks.to(*args, **kwargs)
+        return Masks(masks, self.orig_shape)
+
+    def __len__(self):  # override len(results)
+        return len(self.masks)
+
+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        return (f"Ultralytics YOLO {self.__class__} masks\n" + f"type: {type(self.masks)}\n" +
+                f"shape: {self.masks.shape}\n" + f"dtype: {self.masks.dtype}")
+
+    def __getitem__(self, idx):
+        masks = self.masks[idx]
+        return Masks(masks, self.orig_shape)
+
+
+if __name__ == "__main__":
+    # test examples
+    results = Results(boxes=torch.randn((2, 6)), masks=torch.randn((2, 160, 160)), orig_shape=[640, 640])
+    results = results.cuda()
+    print("--cuda--pass--")
+    results = results.cpu()
+    print("--cpu--pass--")
+    results = results.to("cuda:0")
+    print("--to-cuda--pass--")
+    results = results.to("cpu")
+    print("--to-cpu--pass--")
+    results = results.numpy()
+    print("--numpy--pass--")
+    # box = Boxes(boxes=torch.randn((2, 6)), orig_shape=[5, 5])
+    # box = box.cuda()
+    # box = box.cpu()
+    # box = box.numpy()
+    # for b in box:
+    #     print(b)
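A hedged usage sketch of the new Results container on fabricated tensors (mirroring the __main__ self-test above):

    import torch
    from ultralytics.yolo.engine.results import Results

    boxes = torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.9, 0.0],
                          [30.0, 40.0, 130.0, 240.0, 0.8, 2.0]])  # xyxy, conf, cls
    res = Results(boxes=boxes, orig_shape=(480, 640))
    print(len(res))           # 2 detections
    first = res[0]            # indexing returns a new Results view
    print(first.boxes.xywh)   # converted to xywh on demand
    print(res.boxes.xyxyn)    # normalized by orig_shape[[1, 0, 1, 0]], i.e. (w, h, w, h)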
@@ -30,7 +30,7 @@ from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
 from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr,
                                     yaml_save)
 from ultralytics.yolo.utils.autobatch import check_train_batch_size
-from ultralytics.yolo.utils.checks import check_file, print_args
+from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args
 from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.yolo.utils.files import get_latest_run, increment_path
 from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle, strip_optimizer
@@ -203,7 +203,9 @@ class BaseTrainer:
         self.set_model_attributes()
         if world_size > 1:
             self.model = DDP(self.model, device_ids=[rank])
-
+        # Check imgsz
+        gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs * 2)
         # Batch size
         if self.batch_size == -1:
             if RANK == -1:  # single-GPU only, estimate best batch size
@@ -5,7 +5,6 @@ import inspect
 import logging.config
 import os
 import platform
-import subprocess
 import sys
 import tempfile
 import threading
@@ -13,6 +12,7 @@ import uuid
 from pathlib import Path

 import cv2
+import git
 import numpy as np
 import pandas as pd
 import torch
@@ -134,10 +134,8 @@ def is_git_directory() -> bool:
     Returns:
         bool: True if the current working directory is inside a git repository, False otherwise.
     """
-    import git
     try:
-        from git import Repo
-        Repo(search_parent_directories=True)
+        git.Repo(search_parent_directories=True)
         # subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)  # CLI alternative
         return True
     except git.exc.InvalidGitRepositoryError:  # subprocess.CalledProcessError:
@@ -187,9 +185,10 @@ def get_git_root_dir():
     If the current file is not part of a git repository, returns None.
     """
     try:
-        output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
-        return Path(output.stdout.strip().decode('utf-8')).parent.resolve()  # parent/.git
-    except subprocess.CalledProcessError:
+        # output = subprocess.run(["git", "rev-parse", "--git-dir"], capture_output=True, check=True)
+        # return Path(output.stdout.strip().decode('utf-8')).parent.resolve()  # CLI alternative
+        return Path(git.Repo(search_parent_directories=True).working_tree_dir)
+    except git.exc.InvalidGitRepositoryError:  # (subprocess.CalledProcessError, FileNotFoundError):
         return None

@@ -15,20 +15,39 @@ from .metrics import box_iou


 class Profile(contextlib.ContextDecorator):
-    # YOLOv8 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
+    """
+    YOLOv8 Profile class.
+    Usage: as a decorator with @Profile() or as a context manager with 'with Profile():'
+    """

     def __init__(self, t=0.0):
+        """
+        Initialize the Profile class.
+
+        Args:
+            t (float): Initial time. Defaults to 0.0.
+        """
         self.t = t
         self.cuda = torch.cuda.is_available()

     def __enter__(self):
+        """
+        Start timing.
+        """
         self.start = self.time()
         return self

     def __exit__(self, type, value, traceback):
+        """
+        Stop timing.
+        """
         self.dt = self.time() - self.start  # delta-time
         self.t += self.dt  # accumulate dt

     def time(self):
+        """
+        Get current time.
+        """
         if self.cuda:
             torch.cuda.synchronize()
         return time.time()
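A minimal sketch of both documented usages of Profile (context manager and decorator):

    import time
    from ultralytics.yolo.utils.ops import Profile

    dt = Profile()
    with dt:                       # context manager: times the block
        time.sleep(0.01)
    print(f"{dt.dt * 1E3:.1f}ms")  # delta-time of the last block; dt.t accumulates

    @Profile()                     # decorator usage via contextlib.ContextDecorator
    def work():
        time.sleep(0.01)

    work()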
@@ -48,15 +67,15 @@ def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)

 def segment2box(segment, width=640, height=640):
     """
-    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to
-    (xyxy)
+    Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)

     Args:
-        segment (torch.tensor): the segment label
+        segment (torch.Tensor): the segment label
         width (int): the width of the image. Defaults to 640
         height (int): The height of the image. Defaults to 640

     Returns:
-        (np.array): the minimum and maximum x and y values of the segment.
+        (np.ndarray): the minimum and maximum x and y values of the segment.
     """
     # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
     x, y = segment.T  # segment xy
@@ -67,15 +86,18 @@ def segment2box(segment, width=640, height=640):

 def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
     """
-    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in (img1_shape) to the shape of a different image (img0_shape).
+    Rescales bounding boxes (in the format of xyxy) from the shape of the image they were originally specified in
+    (img1_shape) to the shape of a different image (img0_shape).

     Args:
         img1_shape (tuple): The shape of the image that the bounding boxes are for, in the format of (height, width).
-        boxes (torch.tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
+        boxes (torch.Tensor): the bounding boxes of the objects in the image, in the format of (x1, y1, x2, y2)
         img0_shape (tuple): the shape of the target image, in the format of (height, width).
-        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be calculated based on the size difference between the two images.
+        ratio_pad (tuple): a tuple of (ratio, pad) for scaling the boxes. If not provided, the ratio and pad will be
+            calculated based on the size difference between the two images.

     Returns:
-        boxes (torch.tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
+        boxes (torch.Tensor): The scaled bounding boxes, in the format of (x1, y1, x2, y2)
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
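To make the rescaling concrete, a small worked example of the gain/pad arithmetic (hypothetical shapes):

    img1_shape = (640, 640)  # model input (h, w)
    img0_shape = (480, 640)  # original image (h, w)
    gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # min(1.333, 1.0) = 1.0
    pad_h = (img1_shape[0] - img0_shape[0] * gain) / 2  # 80.0 pixels of vertical letterbox padding
    pad_w = (img1_shape[1] - img0_shape[1] * gain) / 2  # 0.0
    # A y-coordinate of 80 in the padded 640x640 input maps back to 0 in the 480x640 original:
    assert (80 - pad_h) / gain == 0.0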
@@ -92,7 +114,16 @@ def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):


 def make_divisible(x, divisor):
-    # Returns nearest x divisible by divisor
+    """
+    Returns the nearest number that is divisible by the given divisor.
+
+    Args:
+        x (int): The number to make divisible.
+        divisor (int or torch.Tensor): The divisor.
+
+    Returns:
+        int: The nearest number divisible by the divisor.
+    """
     if isinstance(divisor, torch.Tensor):
         divisor = int(divisor.max())  # to int
     return math.ceil(x / divisor) * divisor
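A quick numeric check of the rounding behavior described above:

    import math

    def make_divisible(x, divisor):
        return math.ceil(x / divisor) * divisor

    assert make_divisible(97, 32) == 128  # rounds up to the next multiple of 32
    assert make_divisible(64, 32) == 64   # already divisible, unchanged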
@@ -232,7 +263,7 @@ def clip_boxes(boxes, shape):
     shape

     Args:
-        boxes (torch.tensor): the bounding boxes to clip
+        boxes (torch.Tensor): the bounding boxes to clip
         shape (tuple): the shape of the image
     """
     if isinstance(boxes, torch.Tensor):  # faster individually
@@ -246,7 +277,19 @@ def clip_boxes(boxes, shape):


 def clip_coords(boxes, shape):
-    # Clip bounding xyxy bounding boxes to image shape (height, width)
+    """
+    Clip bounding xyxy bounding boxes to image shape (height, width).
+
+    Args:
+        boxes (torch.Tensor or numpy.ndarray): Bounding boxes to be clipped.
+        shape (tuple): The shape of the image. (height, width)
+
+    Returns:
+        None
+
+    Note:
+        The input `boxes` is modified in-place, there is no return value.
+    """
     if isinstance(boxes, torch.Tensor):  # faster individually
         boxes[:, 0].clamp_(0, shape[1])  # x1
         boxes[:, 1].clamp_(0, shape[0])  # y1
@@ -263,12 +306,12 @@ def scale_image(im1_shape, masks, im0_shape, ratio_pad=None):

     Args:
         im1_shape (tuple): model input shape, [h, w]
-        masks (torch.tensor): [h, w, num]
+        masks (torch.Tensor): [h, w, num]
         im0_shape (tuple): the original image shape
         ratio_pad (tuple): the ratio of the padding to the original image.

     Returns:
-        masks (torch.tensor): The masks that are being returned.
+        masks (torch.Tensor): The masks that are being returned.
     """
     # Rescale coordinates (xyxy) from im1_shape to im0_shape
     if ratio_pad is None:  # calculate from im0_shape
@@ -297,9 +340,9 @@ def xyxy2xywh(x):
     Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format.

     Args:
-        x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format.
+        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
+        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height) format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
@@ -311,12 +354,13 @@ def xyxy2xywh(x):

 def xywh2xyxy(x):
     """
-    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
+    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
+    top-left corner and (x2, y2) is the bottom-right corner.

     Args:
-        x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x, y, width, height) format.
+        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x, y, width, height) format.
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
+        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
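A quick numeric check of the two conversions (following the formulas above):

    import numpy as np

    xyxy = np.array([[10.0, 20.0, 110.0, 220.0]])
    xywh = np.array([[(10 + 110) / 2, (20 + 220) / 2, 110 - 10, 220 - 20]])  # [[60, 120, 100, 200]]
    # xywh2xyxy round-trips back to the original corners:
    x, y, w, h = xywh[0]
    assert [x - w / 2, y - h / 2, x + w / 2, y + h / 2] == list(xyxy[0])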
@@ -337,7 +381,8 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
         padw (int): Padding width. Defaults to 0
         padh (int): Padding height. Defaults to 0
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
+        y (np.ndarray) or (torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
+            x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top left x
@@ -349,16 +394,17 @@ def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):

 def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
     """
-    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, width and height are normalized to image dimensions
+    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format.
+    x, y, width and height are normalized to image dimensions

     Args:
-        x (np.ndarray) or (torch.Tensor): The input tensor containing the bounding box coordinates in (x1, y1, x2, y2) format.
+        x (np.ndarray) or (torch.Tensor): The input bounding box coordinates in (x1, y1, x2, y2) format.
         w (int): The width of the image. Defaults to 640
         h (int): The height of the image. Defaults to 640
         clip (bool): If True, the boxes will be clipped to the image boundaries. Defaults to False
         eps (float): The minimum value of the box's width and height. Defaults to 0.0
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
+        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in (x, y, width, height, normalized) format
     """
     if clip:
         clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
@@ -375,13 +421,13 @@ def xyn2xy(x, w=640, h=640, padw=0, padh=0):
     Convert normalized coordinates to pixel coordinates of shape (n,2)

     Args:
-        x (numpy.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates
+        x (np.ndarray) or (torch.Tensor): The input tensor of normalized bounding box coordinates
         w (int): The width of the image. Defaults to 640
         h (int): The height of the image. Defaults to 640
         padw (int): The width of the padding. Defaults to 0
         padh (int): The height of the padding. Defaults to 0
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box
+        y (np.ndarray) or (torch.Tensor): The x and y coordinates of the top left corner of the bounding box
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[..., 0] = w * x[..., 0] + padw  # top left x
@@ -394,9 +440,9 @@ def xywh2ltwh(x):
     Convert the bounding box format from [x, y, w, h] to [x1, y1, w, h], where x1, y1 are the top-left coordinates.

     Args:
-        x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
+        x (np.ndarray) or (torch.Tensor): The input tensor with the bounding box coordinates in the xywh format
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format
+        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
@@ -409,9 +455,9 @@ def xyxy2ltwh(x):
     Convert nx4 bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h], where xy1=top-left, xy2=bottom-right

     Args:
-        x (numpy.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
+        x (np.ndarray) or (torch.Tensor): The input tensor with the bounding boxes coordinates in the xyxy format
     Returns:
-        y (numpy.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format.
+        y (np.ndarray) or (torch.Tensor): The bounding box coordinates in the xyltwh format.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 2] = x[:, 2] - x[:, 0]  # width
@@ -424,7 +470,7 @@ def ltwh2xywh(x):
     Convert nx4 boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center

     Args:
-        x (torch.tensor): the input tensor
+        x (torch.Tensor): the input tensor
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 0] = x[:, 0] + x[:, 2] / 2  # center x
@@ -437,10 +483,10 @@ def ltwh2xyxy(x):
     It converts the bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right

     Args:
-        x (numpy.ndarray) or (torch.Tensor): the input image
+        x (np.ndarray) or (torch.Tensor): the input image

     Returns:
-        y (numpy.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes.
+        y (np.ndarray) or (torch.Tensor): the xyxy coordinates of the bounding boxes.
     """
     y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
     y[:, 2] = x[:, 2] + x[:, 0]  # width
@@ -456,7 +502,7 @@ def segments2boxes(segments):
         segments (list): list of segments, each segment is a list of points, each point is a list of x, y coordinates

     Returns:
-        (np.array): the xywh coordinates of the bounding boxes.
+        (np.ndarray): the xywh coordinates of the bounding boxes.
     """
     boxes = []
     for s in segments:
@@ -467,7 +513,7 @@ def segments2boxes(segments):

 def resample_segments(segments, n=1000):
     """
-    It takes a list of segments (n,2) and returns a list of segments (n,2) where each segment has been up-sampled to n points
+    Inputs a list of segments (n,2) and returns a list of segments (n,2) up-sampled to n points each.

     Args:
         segments (list): a list of (n,2) arrays, where n is the number of points in the segment.
@@ -489,11 +535,11 @@ def crop_mask(masks, boxes):
     It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box

     Args:
-        masks (torch.tensor): [h, w, n] tensor of masks
-        boxes (torch.tensor): [n, 4] tensor of bbox coordinates in relative point form
+        masks (torch.Tensor): [h, w, n] tensor of masks
+        boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

     Returns:
-        (torch.tensor): The masks are being cropped to the bounding box.
+        (torch.Tensor): The masks are being cropped to the bounding box.
     """
     n, h, w = masks.shape
     x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(1,1,n)
@@ -509,13 +555,13 @@ def process_mask_upsample(protos, masks_in, bboxes, shape):
     quality but is slower.

     Args:
-        protos (torch.tensor): [mask_dim, mask_h, mask_w]
-        masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
-        bboxes (torch.tensor): [n, 4], n is number of masks after nms
+        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
+        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
         shape (tuple): the size of the input image (h,w)

     Returns:
-        (torch.tensor): The upsampled masks.
+        (torch.Tensor): The upsampled masks.
     """
     c, mh, mw = protos.shape  # CHW
     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@@ -530,13 +576,13 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False):
     downsampled quality of mask

     Args:
-        protos (torch.tensor): [mask_dim, mask_h, mask_w]
-        masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
-        bboxes (torch.tensor): [n, 4], n is number of masks after nms
+        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
+        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
         shape (tuple): the size of the input image (h,w)

     Returns:
-        (torch.tensor): The processed masks.
+        (torch.Tensor): The processed masks.
     """

     c, mh, mw = protos.shape  # CHW
@@ -560,13 +606,13 @@ def process_mask_native(protos, masks_in, bboxes, shape):
     It takes the output of the mask head, and crops it after upsampling to the bounding boxes.

     Args:
-        protos (torch.tensor): [mask_dim, mask_h, mask_w]
-        masks_in (torch.tensor): [n, mask_dim], n is number of masks after nms
-        bboxes (torch.tensor): [n, 4], n is number of masks after nms
+        protos (torch.Tensor): [mask_dim, mask_h, mask_w]
+        masks_in (torch.Tensor): [n, mask_dim], n is number of masks after nms
+        bboxes (torch.Tensor): [n, 4], n is number of masks after nms
         shape (tuple): the size of the input image (h,w)

     Returns:
-        masks (torch.tensor): The returned masks with dimensions [h, w, n]
+        masks (torch.Tensor): The returned masks with dimensions [h, w, n]
     """
     c, mh, mw = protos.shape  # CHW
     masks = (masks_in @ protos.float().view(c, -1)).sigmoid().view(-1, mh, mw)
@@ -587,13 +633,13 @@ def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):

     Args:
         img1_shape (tuple): The shape of the image that the segments are from.
-        segments (torch.tensor): the segments to be scaled
+        segments (torch.Tensor): the segments to be scaled
         img0_shape (tuple): the shape of the image that the segmentation is being applied to
         ratio_pad (tuple): the ratio of the image size to the padded image size.
         normalize (bool): If True, the coordinates will be normalized to the range [0, 1]. Defaults to False

     Returns:
-        segments (torch.tensor): the segmented image.
+        segments (torch.Tensor): the segmented image.
     """
     if ratio_pad is None:  # calculate from img0_shape
         gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
@@ -617,7 +663,7 @@ def masks2segments(masks, strategy='largest'):
     It takes a list of masks(n,h,w) and returns a list of segments(n,xy)

     Args:
-        masks (torch.tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
+        masks (torch.Tensor): the output of the model, which is a tensor of shape (batch_size, 160, 160)
     strategy (str): 'concat' or 'largest'. Defaults to largest

     Returns:
@@ -4,8 +4,8 @@ import hydra
 import torch

 from ultralytics.yolo.engine.predictor import BasePredictor
+from ultralytics.yolo.engine.results import Results
 from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT
-from ultralytics.yolo.utils.checks import check_imgsz
 from ultralytics.yolo.utils.plotting import Annotator

@@ -15,20 +15,27 @@ class ClassificationPredictor(BasePredictor):
         return Annotator(img, example=str(self.model.names), pil=True)

     def preprocess(self, img):
-        img = torch.Tensor(img).to(self.model.device)
+        img = (img if isinstance(img, torch.Tensor) else torch.Tensor(img)).to(self.model.device)
         img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
         return img

-    def write_results(self, idx, preds, batch):
+    def postprocess(self, preds, img, orig_img):
+        results = []
+        for i, pred in enumerate(preds):
+            shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape
+            results.append(Results(probs=pred.softmax(0), orig_shape=shape[:2]))
+        return results
+
+    def write_results(self, idx, results, batch):
         p, im, im0 = batch
         log_string = ""
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
         self.seen += 1
         im0 = im0.copy()
-        if self.webcam:  # batch_size >= 1
+        if self.webcam or self.from_img:  # batch_size >= 1
             log_string += f'{idx}: '
-            frame = self.dataset.cound
+            frame = self.dataset.count
         else:
             frame = getattr(self.dataset, 'frame', 0)

@@ -38,9 +45,10 @@ class ClassificationPredictor(BasePredictor):
         log_string += '%gx%g ' % im.shape[2:]  # print string
         self.annotator = self.get_annotator(im0)

-        prob = preds[idx].softmax(0)
-        if self.return_outputs:
-            self.output["prob"] = prob.cpu().numpy()
+        result = results[idx]
+        if len(result) == 0:
+            return log_string
+        prob = result.probs
         # Print results
         top5i = prob.argsort(0, descending=True)[:5].tolist()  # top 5 indices
         log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
@@ -59,7 +67,6 @@ class ClassificationPredictor(BasePredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
-    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size
     cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
     predictor = ClassificationPredictor(cfg)
     predictor.predict_cli()
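A hedged sketch of the classification path: ClassificationPredictor.postprocess now wraps raw logits in Results, which softmaxes them on construction (fabricated tensor values):

    import torch
    from ultralytics.yolo.engine.results import Results

    logits = torch.tensor([2.0, 0.5, 0.1])              # raw classifier output for one image
    res = Results(probs=logits, orig_shape=(224, 224))  # probs are softmaxed in __init__
    top5i = res.probs.argsort(0, descending=True)[:5].tolist()
    print(top5i, res.probs[top5i[0]].item())            # top indices, best probability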
@@ -56,6 +56,8 @@ class ClassificationTrainer(BaseTrainer):
         # Load a YOLO model locally, from torchvision, or from Ultralytics assets
         if model.endswith(".pt"):
             self.model, _ = attempt_load_one_weight(model, device='cpu')
+            for p in self.model.parameters():
+                p.requires_grad = True  # for training
         elif model.endswith(".yaml"):
             self.model = self.get_model(cfg=model)
         elif model in torchvision.models.__dict__:
@@ -4,8 +4,8 @@ import hydra
 import torch

 from ultralytics.yolo.engine.predictor import BasePredictor
+from ultralytics.yolo.engine.results import Results
 from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
-from ultralytics.yolo.utils.checks import check_imgsz
 from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box

@@ -27,58 +27,53 @@ class DetectionPredictor(BasePredictor):
                                         agnostic=self.args.agnostic_nms,
                                         max_det=self.args.max_det)

+        results = []
         for i, pred in enumerate(preds):
-            shape = orig_img[i].shape if self.webcam else orig_img.shape
+            shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape
             pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+            results.append(Results(boxes=pred, orig_shape=shape[:2]))
+        return results

-        return preds
-
-    def write_results(self, idx, preds, batch):
+    def write_results(self, idx, results, batch):
         p, im, im0 = batch
         log_string = ""
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
         self.seen += 1
         im0 = im0.copy()
-        if self.webcam:  # batch_size >= 1
+        if self.webcam or self.from_img:  # batch_size >= 1
             log_string += f'{idx}: '
             frame = self.dataset.count
         else:
             frame = getattr(self.dataset, 'frame', 0)

         self.data_path = p
         # save_path = str(self.save_dir / p.name)  # im.jpg
         self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
         log_string += '%gx%g ' % im.shape[2:]  # print string
         self.annotator = self.get_annotator(im0)

-        det = preds[idx]
+        det = results[idx].boxes  # TODO: make boxes inherit from tensors
         if len(det) == 0:
             return log_string
-        for c in det[:, 5].unique():
-            n = (det[:, 5] == c).sum()  # detections per class
+        for c in det.cls.unique():
+            n = (det.cls == c).sum()  # detections per class
             log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

-        if self.return_outputs:
-            self.output["det"] = det.cpu().numpy()
-
         # write
-        gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
-        for *xyxy, conf, cls in reversed(det):
+        for d in reversed(det):
+            cls, conf = d.cls.squeeze(), d.conf.squeeze()
             if self.args.save_txt:  # Write to file
-                xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
-                line = (cls, *xywh, conf) if self.args.save_conf else (cls, *xywh)  # label format
+                line = (cls, *(d.xywhn.view(-1).tolist()), conf) \
+                    if self.args.save_conf else (cls, *(d.xywhn.view(-1).tolist()))  # label format
                 with open(f'{self.txt_path}.txt', 'a') as f:
                     f.write(('%g ' * len(line)).rstrip() % line + '\n')

             if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
                 c = int(cls)  # integer class
                 label = None if self.args.hide_labels else (
                     self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
-                self.annotator.box_label(xyxy, label, color=colors(c, True))
+                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
             if self.args.save_crop:
                 imc = im0.copy()
-                save_one_box(xyxy,
+                save_one_box(d.xyxy,
                              imc,
                              file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',
                              BGR=True)
@@ -89,7 +84,6 @@ class DetectionPredictor(BasePredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n.pt"
-    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size
     cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
     predictor = DetectionPredictor(cfg)
     predictor.predict_cli()
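A hedged sketch of the per-detection access pattern write_results now relies on; reversed() iterates the Boxes object through its __len__/__getitem__, yielding single-row Boxes views (fabricated values):

    import torch
    from ultralytics.yolo.engine.results import Boxes

    det = Boxes(torch.tensor([[10.0, 20.0, 110.0, 220.0, 0.9, 0.0],
                              [30.0, 40.0, 130.0, 240.0, 0.8, 2.0]]), orig_shape=(480, 640))
    for d in reversed(det):
        cls, conf = d.cls.squeeze(), d.conf.squeeze()
        print(int(cls), float(conf), d.xywhn.view(-1).tolist())  # normalized label row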
@@ -3,8 +3,8 @@
 import hydra
 import torch

+from ultralytics.yolo.engine.results import Results
 from ultralytics.yolo.utils import DEFAULT_CONFIG, ROOT, ops
-from ultralytics.yolo.utils.checks import check_imgsz
 from ultralytics.yolo.utils.plotting import colors, save_one_box
 from ultralytics.yolo.v8.detect.predict import DetectionPredictor

@@ -12,7 +12,6 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
 class SegmentationPredictor(DetectionPredictor):

     def postprocess(self, preds, img, orig_img):
-        masks = []
         # TODO: filter by classes
         p = ops.non_max_suppression(preds[0],
                                     self.args.conf,
@@ -20,27 +19,29 @@ class SegmentationPredictor(DetectionPredictor):
                                     agnostic=self.args.agnostic_nms,
                                     max_det=self.args.max_det,
                                     nm=32)
+        results = []
         proto = preds[1][-1]
         for i, pred in enumerate(p):
-            shape = orig_img[i].shape if self.webcam else orig_img.shape
+            shape = orig_img[i].shape if isinstance(orig_img, list) else orig_img.shape
+            if not len(pred):
+                results.append(Results(boxes=pred[:, :6], orig_shape=shape[:2]))  # save empty boxes
+                continue
             if self.args.retina_masks:
                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
-                masks.append(ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2]))  # HWC
+                masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], shape[:2])  # HWC
             else:
-                masks.append(ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True))  # HWC
+                masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
                 pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
+            results.append(Results(boxes=pred[:, :6], masks=masks, orig_shape=shape[:2]))
+        return results

-        return (p, masks)
-
-    def write_results(self, idx, preds, batch):
+    def write_results(self, idx, results, batch):
         p, im, im0 = batch
         log_string = ""
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
         self.seen += 1
-        if self.webcam:  # batch_size >= 1
+        if self.webcam or self.from_img:  # batch_size >= 1
             log_string += f'{idx}: '
             frame = self.dataset.count
         else:
@@ -51,54 +52,48 @@ class SegmentationPredictor(DetectionPredictor):
         log_string += '%gx%g ' % im.shape[2:]  # print string
         self.annotator = self.get_annotator(im0)

-        preds, masks = preds
-        det = preds[idx]
-        if len(det) == 0:
+        result = results[idx]
+        if len(result) == 0:
             return log_string
-        # Segments
-        mask = masks[idx]
-        if self.args.save_txt or self.return_outputs:
-            shape = im0.shape if self.args.retina_masks else im.shape[2:]
-            segments = [
-                ops.scale_segments(shape, x, im0.shape, normalize=False) for x in reversed(ops.masks2segments(mask))]
+        det, mask = result.boxes, result.masks  # getting tensors TODO: mask mask,box inherit for tensor

         # Print results
-        for c in det[:, 5].unique():
-            n = (det[:, 5] == c).sum()  # detections per class
-            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "  # add to string
+        for c in det.cls.unique():
+            n = (det.cls == c).sum()  # detections per class
+            log_string += f"{n} {self.model.names[int(c)]}{'s' * (n > 1)}, "

         # Mask plotting
         self.annotator.masks(
-            mask,
-            colors=[colors(x, True) for x in det[:, 5]],
+            mask.masks,
+            colors=[colors(x, True) for x in det.cls],
             im_gpu=torch.as_tensor(im0, dtype=torch.float16).to(self.device).permute(2, 0, 1).flip(0).contiguous() /
             255 if self.args.retina_masks else im[idx])

-        det = reversed(det[:, :6])
-        if self.return_outputs:
-            self.output["det"] = det.cpu().numpy()
-            self.output["segment"] = segments
+        # Segments
+        if self.args.save_txt:
+            segments = mask.segments

         # Write results
-        for j, (*xyxy, conf, cls) in enumerate(det):
+        for j, d in enumerate(reversed(det)):
+            cls, conf = d.cls.squeeze(), d.conf.squeeze()
             if self.args.save_txt:  # Write to file
                 seg = segments[j].copy()
-                seg[:, 0] /= shape[1]  # width
-                seg[:, 1] /= shape[0]  # height
                 seg = seg.reshape(-1)  # (n,2) to (n*2)
                 line = (cls, *seg, conf) if self.args.save_conf else (cls, *seg)  # label format
                 with open(f'{self.txt_path}.txt', 'a') as f:
                     f.write(('%g ' * len(line)).rstrip() % line + '\n')

-            if self.args.save or self.args.save_crop or self.args.show:
+            if self.args.save or self.args.save_crop or self.args.show:  # Add bbox to image
                 c = int(cls)  # integer class
                 label = None if self.args.hide_labels else (
                     self.model.names[c] if self.args.hide_conf else f'{self.model.names[c]} {conf:.2f}')
-                self.annotator.box_label(xyxy, label, color=colors(c, True))
-                # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
+                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
             if self.args.save_crop:
                 imc = im0.copy()
-                save_one_box(xyxy, imc, file=self.save_dir / 'crops' / self.model.names[c] / f'{p.stem}.jpg', BGR=True)
+                save_one_box(d.xyxy,
+                             imc,
+                             file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',
+                             BGR=True)

         return log_string
@@ -106,7 +101,6 @@ class SegmentationPredictor(DetectionPredictor):
 @hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
 def predict(cfg):
     cfg.model = cfg.model or "yolov8n-seg.pt"
-    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size
     cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"

     predictor = SegmentationPredictor(cfg)