General refactoring and improvements (#373)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Glenn Jocher committed 2 years ago via GitHub
parent ac628c0d3e
commit 583eac0e80

@ -13,6 +13,7 @@ on:
jobs:
  CLA:
+   if: github.repository == 'ultralytics/ultralytics'
    runs-on: ubuntu-latest
    steps:
      - name: "CLA Assistant"

@ -7,7 +7,7 @@ import requests
from ultralytics import __version__
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
- from ultralytics.yolo.utils import LOGGER, is_colab, threaded
+ from ultralytics.yolo.utils import is_colab, threaded

AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'

@ -32,21 +32,21 @@ class AutoBackend(nn.Module):
fp16 (bool): If True, use half precision. Default: False
fuse (bool): Whether to fuse the model or not. Default: True

- Supported formats and their usage:
-     Platform              | Weights Format
-     ----------------------|------------------
-     PyTorch               | *.pt
-     TorchScript           | *.torchscript
-     ONNX Runtime          | *.onnx
-     ONNX OpenCV DNN       | *.onnx --dnn
-     OpenVINO              | *.xml
-     CoreML                | *.mlmodel
-     TensorRT              | *.engine
-     TensorFlow SavedModel | *_saved_model
-     TensorFlow GraphDef   | *.pb
-     TensorFlow Lite       | *.tflite
-     TensorFlow Edge TPU   | *_edgetpu.tflite
-     PaddlePaddle          | *_paddle_model
+ Supported formats and their naming conventions:
+     | Format                | Suffix           |
+     |-----------------------|------------------|
+     | PyTorch               | *.pt             |
+     | TorchScript           | *.torchscript    |
+     | ONNX Runtime          | *.onnx           |
+     | ONNX OpenCV DNN       | *.onnx --dnn     |
+     | OpenVINO              | *.xml            |
+     | CoreML                | *.mlmodel        |
+     | TensorRT              | *.engine         |
+     | TensorFlow SavedModel | *_saved_model    |
+     | TensorFlow GraphDef   | *.pb             |
+     | TensorFlow Lite       | *.tflite         |
+     | TensorFlow Edge TPU   | *_edgetpu.tflite |
+     | PaddlePaddle          | *_paddle_model   |
"""
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
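A minimal usage sketch of the convention above: AutoBackend picks the runtime from the weights suffix. The file name is illustrative, and the matching runtime (here ONNX Runtime) must be installed:

```python
import torch

from ultralytics.nn.autobackend import AutoBackend

model = AutoBackend('yolov8n.onnx')  # *.onnx suffix -> ONNX Runtime backend
im = torch.zeros(1, 3, 640, 640)     # BCHW dummy input
y = model(im)                        # inference through the selected backend
```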
@ -357,7 +357,7 @@ class AutoBackend(nn.Module):
This function takes a path to a model file and returns the model type

Args:
    p: path to the model file. Defaults to path/to/model.pt
"""
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
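As a companion to the comment above, a hedged sketch of suffix-based type detection; the helper name and suffix spellings are illustrative, not the actual method body:

```python
from pathlib import Path

def guess_model_type(p='path/to/model.pt'):
    # One boolean per format, following the order of the comment above.
    suffixes = ['.pt', '.torchscript', '.onnx', '.xml', '.engine', '.mlmodel',
                '_saved_model', '.pb', '.tflite', '_edgetpu.tflite', '_web_model', '_paddle_model']
    name = Path(p).name.lower()
    flags = [name.endswith(s) for s in suffixes]
    if flags[9]:          # *_edgetpu.tflite also ends with '.tflite'
        flags[8] = False  # keep only the more specific Edge TPU match
    return flags

pt, jit, onnx, *rest = guess_model_type('yolov8n.onnx')  # onnx=True, everything else False
```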
@ -374,12 +374,11 @@ class AutoBackend(nn.Module):
@staticmethod
def _load_metadata(f=Path('path/to/meta.yaml')):
    """
-   > Loads the metadata from a yaml file
+   Loads the metadata from a yaml file

    Args:
        f: The path to the metadata file.
    """
- from ultralytics.yolo.utils.files import yaml_load

# Load metadata from meta.yaml if it exists
if f.exists():
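A hedged sketch of what this helper does, assuming a PyYAML-style reader; the returned keys follow typical export metadata and are an assumption:

```python
from pathlib import Path

import yaml  # PyYAML

def load_metadata(f=Path('path/to/meta.yaml')):
    # Load stride/names written next to an exported model, if the file exists.
    if f.exists():
        d = yaml.safe_load(f.read_text())
        return d['stride'], d['names']
    return None, None
```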

@ -5,28 +5,11 @@ Common modules
import math
import warnings
- from copy import copy
- from pathlib import Path
- import cv2
- import numpy as np
- import pandas as pd
- import requests
import torch
import torch.nn as nn
- from PIL import Image, ImageOps
- from torch.cuda import amp
- from ultralytics.nn.autobackend import AutoBackend
- from ultralytics.yolo.data.augment import LetterBox
- from ultralytics.yolo.utils import LOGGER, colorstr
- from ultralytics.yolo.utils.files import increment_path
- from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
- from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
- from ultralytics.yolo.utils.tal import dist2bbox, make_anchors
- from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
- # from utils.plots import feature_visualization TODO
+ from ultralytics.yolo.utils.tal import dist2bbox, make_anchors

def autopad(k, p=None, d=1):  # kernel, padding, dilation
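For context, only the `autopad` signature appears in this hunk; the body below is the standard 'same'-shape padding rule it is understood to implement (a sketch, not a quote of the file):

```python
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    # Pad so output spatial size matches input size for stride-1 convolutions.
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # effective kernel
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # 'same' padding
    return p

assert autopad(3) == 1 and autopad(5) == 2 and autopad(3, d=2) == 2
```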
@ -365,216 +348,6 @@ class Concat(nn.Module):
return torch.cat(x, self.d)
class AutoShape(nn.Module):
# YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
agnostic = False # NMS class-agnostic
multi_label = False # NMS multiple labels per box
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference
def __init__(self, model, verbose=True):
super().__init__()
if verbose:
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model
self.model = model.eval()
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.inplace = False # Detect.inplace=False for safe multithread inference
m.export = True # do not output loss values
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self
@smart_inference_mode()
def forward(self, ims, size=640, augment=False, profile=False):
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
# URI: = 'https://ultralytics.com/images/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
# numpy: = np.zeros((640,1280,3)) # HWC
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
dt = (Profile(), Profile(), Profile())
with dt[0]:
if isinstance(size, int): # expand
size = (size, size)
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast):
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
# Pre-process
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
for i, im in enumerate(ims):
f = f'image{i}' # filename
if isinstance(im, (str, Path)): # filename or uri
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
im = np.asarray(ImageOps.exif_transpose(im))
elif isinstance(im, Image.Image): # PIL Image
im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
files.append(Path(f).with_suffix('.jpg').name)
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape
g = max(size) / max(s) # gain
shape1.append([y * g for y in s])
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
with amp.autocast(autocast):
# Inference
with dt[1]:
y = self.model(x, augment=augment) # forward
# Post-process
with dt[2]:
y = non_max_suppression(y if self.dmb else y[0],
self.conf,
self.iou,
self.classes,
self.agnostic,
self.multi_label,
max_det=self.max_det) # NMS
for i in range(n):
scale_boxes(shape1, y[i][:, :4], shape0[i])
return Detections(ims, y, files, dt, self.names, x.shape)
class Detections:
# YOLOv8 detections class for inference results
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
super().__init__()
d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
self.ims = ims # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.files = files # image filenames
self.times = times # profiling times
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred) # number of images (batch size)
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
self.s = tuple(shape) # inference BCHW shape
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
s, crops = '', []
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
if pred.shape[0]:
for c in pred[:, -1].unique():
n = (pred[:, -1] == c).sum() # detections per class
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
s = s.rstrip(', ')
if show or save or render or crop:
annotator = Annotator(im, example=str(self.names))
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
label = f'{self.names[int(cls)]} {conf:.2f}'
if crop:
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
crops.append({
'box': box,
'conf': conf,
'cls': cls,
'label': label,
'im': save_one_box(box, im, file=file, save=save)})
else: # all others
annotator.box_label(box, label if labels else '', color=colors(cls))
im = annotator.im
else:
s += '(no detections)'
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show:
im.show(self.files[i]) # show
if save:
f = self.files[i]
im.save(save_dir / f) # save
if i == self.n - 1:
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
if render:
self.ims[i] = np.asarray(im)
if pprint:
s = s.lstrip('\n')
return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
if crop:
if save:
LOGGER.info(f'Saved results to {save_dir}\n')
return crops
def show(self, labels=True):
self._run(show=True, labels=labels) # show results
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
self._run(save=True, labels=labels, save_dir=save_dir) # save results
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
def render(self, labels=True):
self._run(render=True, labels=labels) # render results
return self.ims
def pandas(self):
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
new = copy(self) # return copy
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
return new
def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
r = range(self.n) # iterable
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
# for d in x:
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list
return x
def print(self):
LOGGER.info(self.__str__())
def __len__(self): # override len(results)
return self.n
def __str__(self): # override print(results)
return self._run(pprint=True) # print results
def __repr__(self):
return f'YOLOv8 {self.__class__} instance\n' + self.__str__()
class Proto(nn.Module):
    # YOLOv8 mask Proto module for segmentation models
    def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks

@ -0,0 +1,237 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
"""
Common modules
"""
from copy import copy
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
from PIL import Image, ImageOps
from torch.cuda import amp
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.yolo.data.augment import LetterBox
from ultralytics.yolo.utils import LOGGER, colorstr
from ultralytics.yolo.utils.files import increment_path
from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
class AutoShape(nn.Module):
# YOLOv8 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
agnostic = False # NMS class-agnostic
multi_label = False # NMS multiple labels per box
classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
max_det = 1000 # maximum number of detections per image
amp = False # Automatic Mixed Precision (AMP) inference
def __init__(self, model, verbose=True):
super().__init__()
if verbose:
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.dmb = isinstance(model, AutoBackend) # DetectMultiBackend() instance
self.pt = not self.dmb or model.pt # PyTorch model
self.model = model.eval()
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.inplace = False # Detect.inplace=False for safe multithread inference
m.export = True # do not output loss values
def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
if self.pt:
m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
m.stride = fn(m.stride)
m.grid = list(map(fn, m.grid))
if isinstance(m.anchor_grid, list):
m.anchor_grid = list(map(fn, m.anchor_grid))
return self
@smart_inference_mode()
def forward(self, ims, size=640, augment=False, profile=False):
# Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
# file: ims = 'data/images/zidane.jpg' # str or PosixPath
# URI: = 'https://ultralytics.com/images/zidane.jpg'
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
# PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
# numpy: = np.zeros((640,1280,3)) # HWC
# torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
dt = (Profile(), Profile(), Profile())
with dt[0]:
if isinstance(size, int): # expand
size = (size, size)
p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
if isinstance(ims, torch.Tensor): # torch
with amp.autocast(autocast):
return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
# Pre-process
n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
shape0, shape1, files = [], [], [] # image and inference shapes, filenames
for i, im in enumerate(ims):
f = f'image{i}' # filename
if isinstance(im, (str, Path)): # filename or uri
im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
im = np.asarray(ImageOps.exif_transpose(im))
elif isinstance(im, Image.Image): # PIL Image
im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
files.append(Path(f).with_suffix('.jpg').name)
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape
g = max(size) / max(s) # gain
shape1.append([y * g for y in s])
ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims] # pad
x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
with amp.autocast(autocast):
# Inference
with dt[1]:
y = self.model(x, augment=augment) # forward
# Post-process
with dt[2]:
y = non_max_suppression(y if self.dmb else y[0],
self.conf,
self.iou,
self.classes,
self.agnostic,
self.multi_label,
max_det=self.max_det) # NMS
for i in range(n):
scale_boxes(shape1, y[i][:, :4], shape0[i])
return Detections(ims, y, files, dt, self.names, x.shape)
class Detections:
# YOLOv8 detections class for inference results
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
super().__init__()
d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
self.ims = ims # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.files = files # image filenames
self.times = times # profiling times
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred) # number of images (batch size)
self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
self.s = tuple(shape) # inference BCHW shape
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
s, crops = '', []
for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
if pred.shape[0]:
for c in pred[:, -1].unique():
n = (pred[:, -1] == c).sum() # detections per class
s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
s = s.rstrip(', ')
if show or save or render or crop:
annotator = Annotator(im, example=str(self.names))
for *box, conf, cls in reversed(pred): # xyxy, confidence, class
label = f'{self.names[int(cls)]} {conf:.2f}'
if crop:
file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
crops.append({
'box': box,
'conf': conf,
'cls': cls,
'label': label,
'im': save_one_box(box, im, file=file, save=save)})
else: # all others
annotator.box_label(box, label if labels else '', color=colors(cls))
im = annotator.im
else:
s += '(no detections)'
im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
if show:
im.show(self.files[i]) # show
if save:
f = self.files[i]
im.save(save_dir / f) # save
if i == self.n - 1:
LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
if render:
self.ims[i] = np.asarray(im)
if pprint:
s = s.lstrip('\n')
return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
if crop:
if save:
LOGGER.info(f'Saved results to {save_dir}\n')
return crops
def show(self, labels=True):
self._run(show=True, labels=labels) # show results
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
self._run(save=True, labels=labels, save_dir=save_dir) # save results
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
return self._run(crop=True, save=save, save_dir=save_dir) # crop results
def render(self, labels=True):
self._run(render=True, labels=labels) # render results
return self.ims
def pandas(self):
# return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
new = copy(self) # return copy
ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
return new
def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
r = range(self.n) # iterable
x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
# for d in x:
# for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
# setattr(d, k, getattr(d, k)[0]) # pop out of list
return x
def print(self):
LOGGER.info(self.__str__())
def __len__(self): # override len(results)
return self.n
def __str__(self): # override print(results)
return self._run(pprint=True) # print results
def __repr__(self):
return f'YOLOv8 {self.__class__} instance\n' + self.__str__()
print('works')
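The two classes above were moved out of modules.py into their own file; a hedged usage sketch from the caller's side (the destination module path and weights file are assumptions):

```python
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.nn.autoshape import AutoShape  # assumed destination module for the moved classes

model = AutoShape(AutoBackend('yolov8n.pt'))           # accepts cv2/np/PIL/torch inputs
results = model('https://ultralytics.com/images/zidane.jpg', size=640)
results.print()                                        # per-image summary, e.g. '2 persons, 1 tie'
df = results.pandas().xyxy[0]                          # DataFrame: xmin, ymin, xmax, ymax, confidence, class, name
crops = results.crop(save=False)                       # dicts with 'box', 'conf', 'cls', 'label', 'im'
```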

@ -57,7 +57,7 @@ class BaseModel(nn.Module):
x = m(x)  # run
y.append(x if m.i in self.save else None)  # save output
if visualize:
-   pass
+   LOGGER.info('visualize feature not yet supported')
    # TODO: feature_visualization(x, m.type, m.i, save_dir=visualize)
return x
@ -106,8 +106,8 @@ class BaseModel(nn.Module):
Prints model information

Args:
    verbose (bool): if True, prints out the model information. Defaults to False
    imgsz (int): the size of the image that the model will be trained on. Defaults to 640
"""
model_info(self, verbose, imgsz)
@ -117,10 +117,10 @@ class BaseModel(nn.Module):
parameters or registered buffers

Args:
    fn: the function to apply to the model

Returns:
    A model that is a Detect() object.
"""
self = super()._apply(fn)
m = self.model[-1]  # Detect()
@ -135,7 +135,7 @@ class BaseModel(nn.Module):
This function loads the weights of the model from a file

Args:
    weights (str): The weights to load into the model.
"""
# Force all tasks to implement this function
raise NotImplementedError("This function needs to be implemented by derived classes!")

@ -32,7 +32,7 @@ class YOLO:
def __init__(self, model='yolov8n.yaml', type="v8") -> None:
    """
-   > Initializes the YOLO object.
+   Initializes the YOLO object.

    Args:
        model (str, Path): model to load or create
@ -59,7 +59,7 @@ class YOLO:
def _new(self, cfg: str, verbose=True):
    """
-   > Initializes a new model and infers the task type from the model definitions.
+   Initializes a new model and infers the task type from the model definitions.

    Args:
        cfg (str): model configuration file
@ -75,7 +75,7 @@ class YOLO:
def _load(self, weights: str):
    """
-   > Initializes a new model and infers the task type from the model head.
+   Initializes a new model and infers the task type from the model head.

    Args:
        weights (str): model checkpoint to be loaded
@ -90,7 +90,7 @@ class YOLO:
def reset(self):
    """
-   > Resets the model modules.
+   Resets the model modules.
    """
    for m in self.model.modules():
        if hasattr(m, 'reset_parameters'):
@ -100,7 +100,7 @@ class YOLO:
def info(self, verbose=False):
    """
-   > Logs model info.
+   Logs model info.

    Args:
        verbose (bool): Controls verbosity.
@ -133,7 +133,7 @@ class YOLO:
@smart_inference_mode()
def val(self, data=None, **kwargs):
    """
-   > Validate a model on a given dataset .
+   Validate a model on a given dataset.

    Args:
        data (str): The dataset to validate on. Accepts all formats accepted by yolo
@ -152,7 +152,7 @@ class YOLO:
@smart_inference_mode()
def export(self, **kwargs):
    """
-   > Export model.
+   Export model.

    Args:
        **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
@ -168,7 +168,7 @@ class YOLO:
def train(self, **kwargs):
    """
-   > Trains the model on a given dataset.
+   Trains the model on a given dataset.

    Args:
        **kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
@ -197,7 +197,7 @@ class YOLO:
def to(self, device):
    """
-   > Sends the model to the given device.
+   Sends the model to the given device.

    Args:
        device (str): device
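The hunks above de-quote the YOLO docstrings; together they outline the public API. A short usage sketch (model, dataset, and epoch count are illustrative):

```python
from ultralytics import YOLO

model = YOLO('yolov8n.yaml')                # create from config, or YOLO('yolov8n.pt') to load weights
model.info(verbose=True)                    # log model info
model.train(data='coco128.yaml', epochs=3)  # kwargs per the 'config' section of the docs
metrics = model.val()                       # validate; accepts the same data formats as the yolo CLI
model.export(format='onnx')                 # export; kwargs per the 'configuration' section
```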

@ -89,7 +89,7 @@ class BasePredictor:
self.vid_path, self.vid_writer = None, None
self.annotator = None
self.data_path = None
- self.output = dict()
+ self.output = {}
self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks
callbacks.add_integration_callbacks(self)
@ -216,7 +216,7 @@ class BasePredictor:
self.run_callbacks("on_predict_end") self.run_callbacks("on_predict_end")
def predict_cli(self, source=None, model=None, return_outputs=False): def predict_cli(self, source=None, model=None, return_outputs=False):
# as __call__ is a genertor now so have to treat it like a genertor # as __call__ is a generator now so have to treat it like a generator
for _ in (self.__call__(source, model, return_outputs)): for _ in (self.__call__(source, model, return_outputs)):
pass pass
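A minimal illustration of the corrected comment: a generator runs lazily, so `predict_cli` has to drain it for the prediction loop to execute at all (toy stand-in, not the real `__call__`):

```python
def stream():                           # stand-in for the generator-style __call__
    for i in range(3):
        print(f'processing batch {i}')  # work happens only when the generator is driven
        yield i

gen = stream()   # nothing has run yet: generators are lazy
for _ in gen:    # draining it drives the loop, exactly what predict_cli does
    pass
```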

@ -40,7 +40,7 @@ class BaseTrainer:
""" """
BaseTrainer BaseTrainer
> A base class for creating trainers. A base class for creating trainers.
Attributes: Attributes:
args (OmegaConf): Configuration for the trainer. args (OmegaConf): Configuration for the trainer.
@ -75,7 +75,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
    """
-   > Initializes the BaseTrainer class.
+   Initializes the BaseTrainer class.

    Args:
        config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
@ -149,13 +149,13 @@ class BaseTrainer:
def add_callback(self, event: str, callback):
    """
-   > Appends the given callback.
+   Appends the given callback.
    """
    self.callbacks[event].append(callback)

def set_callback(self, event: str, callback):
    """
-   > Overrides the existing callbacks with the given callback.
+   Overrides the existing callbacks with the given callback.
    """
    self.callbacks[event] = [callback]
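A self-contained sketch of how the two registration styles differ (the stand-in Trainer mirrors just these two methods):

```python
from collections import defaultdict

class Trainer:  # minimal stand-in exposing only the two methods above
    def __init__(self):
        self.callbacks = defaultdict(list)
    def add_callback(self, event, callback):
        self.callbacks[event].append(callback)
    def set_callback(self, event, callback):
        self.callbacks[event] = [callback]

t = Trainer()
t.add_callback('on_train_epoch_end', lambda trainer: print('A'))  # appended
t.add_callback('on_train_epoch_end', lambda trainer: print('B'))  # appended: A and B both run
t.set_callback('on_train_epoch_end', lambda trainer: print('C'))  # replaced: only C remains
```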
@ -194,7 +194,7 @@ class BaseTrainer:
def _setup_train(self, rank, world_size):
    """
-   > Builds dataloaders and optimizer on correct rank process.
+   Builds dataloaders and optimizer on correct rank process.
    """
    # model
    self.run_callbacks("on_pretrain_routine_start")
@ -383,13 +383,13 @@ class BaseTrainer:
def get_dataset(self, data):
    """
-   > Get train, val path from data dict if it exists. Returns None if data format is not recognized.
+   Get train, val path from data dict if it exists. Returns None if data format is not recognized.
    """
    return data["train"], data.get("val") or data.get("test")

def setup_model(self):
    """
-   > load/create/download model for any task.
+   load/create/download model for any task.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return
@ -415,13 +415,13 @@ class BaseTrainer:
def preprocess_batch(self, batch):
    """
-   > Allows custom preprocessing model inputs and ground truths depending on task type.
+   Allows custom preprocessing model inputs and ground truths depending on task type.
    """
    return batch

def validate(self):
    """
-   > Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
+   Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
    """
    metrics = self.validator(self)
    fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
@ -431,7 +431,7 @@ class BaseTrainer:
def log(self, text, rank=-1):
    """
-   > Logs the given text to given ranks process if provided, otherwise logs to all ranks.
+   Logs the given text to the given rank's process if provided, otherwise logs to all ranks.

    Args:
        text (str): text to log
@ -449,13 +449,13 @@ class BaseTrainer:
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
    """
-   > Returns dataloader derived from torch.data.Dataloader.
+   Returns dataloader derived from torch.data.Dataloader.
    """
    raise NotImplementedError("get_dataloader function not implemented in trainer")

def criterion(self, preds, batch):
    """
-   > Returns loss and individual loss items as Tensor.
+   Returns loss and individual loss items as Tensor.
    """
    raise NotImplementedError("criterion function not implemented in trainer")
@ -543,7 +543,7 @@ class BaseTrainer:
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    """
-   > Builds an optimizer with the specified parameters and parameter groups.
+   Builds an optimizer with the specified parameters and parameter groups.

    Args:
        model (nn.Module): model to optimize
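A hedged sketch of the usual three-group pattern such builders follow (weights with decay, norm weights without, biases without); the grouping details are assumptions, not a quote of this method's body:

```python
import torch
import torch.nn as nn

def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    # Three parameter groups: biases (no decay), norm weights (no decay), other weights (decay).
    g_weight, g_norm, g_bias = [], [], []
    for m in model.modules():
        if hasattr(m, 'bias') and isinstance(m.bias, nn.Parameter):
            g_bias.append(m.bias)
        if isinstance(m, nn.BatchNorm2d):
            g_norm.append(m.weight)
        elif hasattr(m, 'weight') and isinstance(m.weight, nn.Parameter):
            g_weight.append(m.weight)
    if name == 'Adam':
        optimizer = torch.optim.Adam(g_bias, lr=lr, betas=(momentum, 0.999))
    else:
        optimizer = torch.optim.SGD(g_bias, lr=lr, momentum=momentum, nesterov=True)
    optimizer.add_param_group({'params': g_weight, 'weight_decay': decay})
    optimizer.add_param_group({'params': g_norm, 'weight_decay': 0.0})
    return optimizer

opt = build_optimizer(nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)))
```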

@ -10,7 +10,7 @@ except (ModuleNotFoundError, ImportError):
def on_pretrain_routine_start(trainer):
-   experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8",)
+   experiment = comet_ml.Experiment(project_name=trainer.args.project or "YOLOv8")
    experiment.log_parameters(dict(trainer.args))

@ -12,7 +12,7 @@ from zipfile import ZipFile
import requests
import torch

- from ultralytics.yolo.utils import LOGGER
+ from ultralytics.yolo.utils import LOGGER, SETTINGS

def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):

@ -59,7 +59,11 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
    return response['tag_name'], [x['name'] for x in response['assets']]  # tag, assets

file = Path(str(file).strip().replace("'", ''))
- if not file.exists():
+ if file.exists():
+     return str(file)
+ elif (SETTINGS['weights_dir'] / file).exists():
+     return str(SETTINGS['weights_dir'] / file)
+ else:
    # URL specified
    name = Path(urllib.parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
    if str(file).startswith(('http:/', 'https:/')):  # download

@ -94,7 +98,7 @@ def attempt_download(file, repo='ultralytics/assets', release='v0.0.0'):
    min_bytes=1E5,
    error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}')
return str(file)

def download(url, dir=Path.cwd(), unzip=True, delete=True, curl=False, threads=1, retry=3):
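The new branches give `attempt_download` a three-step lookup; a hedged sketch of the call (the import path is an assumption from this file's context):

```python
from ultralytics.yolo.utils.downloads import attempt_download  # assumed module path

weights = attempt_download('yolov8n.pt')
# 1. 'yolov8n.pt' itself, if the file already exists
# 2. SETTINGS['weights_dir'] / 'yolov8n.pt', if cached in the global weights directory
# 3. otherwise: download from the 'ultralytics/assets' GitHub release and return that path
```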

@ -58,10 +58,9 @@ class ClassificationPredictor(BasePredictor):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def predict(cfg):
-   cfg.model = cfg.model or "squeezenet1_0"
+   cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
    cfg.imgsz = check_imgsz(cfg.imgsz, min_dim=2)  # check image size
    cfg.source = cfg.source if cfg.source is not None else ROOT / "assets"
    predictor = ClassificationPredictor(cfg)
    predictor.predict_cli()

@ -136,7 +136,7 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
-   cfg.model = cfg.model or "yolov8n-cls.yaml"  # or "resnet18"
+   cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
    cfg.data = cfg.data or "mnist160"  # or yolo.ClassificationDataset("mnist")
    cfg.lr0 = 0.1
    cfg.weight_decay = 5e-5
@ -151,10 +151,4 @@ def train(cfg):
if __name__ == "__main__":
-   """
-   yolo task=classify mode=train model=yolov8n-cls.pt data=mnist160 epochs=10 imgsz=32
-   yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32
-   yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
-   yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
-   """
    train()

@ -48,8 +48,8 @@ class ClassificationValidator(BaseValidator):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):
+   cfg.model = cfg.model or "yolov8n-cls.pt"  # or "resnet18"
    cfg.data = cfg.data or "imagenette160"
-   cfg.model = cfg.model or "resnet18"
    validator = ClassificationValidator(args=cfg)
    validator(model=cfg.model)

@ -197,7 +197,7 @@ class Loss:
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
-   cfg.model = cfg.model or "yolov8n.yaml"
+   cfg.model = cfg.model or "yolov8n.pt"
    cfg.data = cfg.data or "coco128.yaml"  # or yolo.ClassificationDataset("mnist")
    cfg.device = cfg.device if cfg.device is not None else ''
    # trainer = DetectionTrainer(cfg)
@ -208,11 +208,4 @@ def train(cfg):
if __name__ == "__main__":
-   """
-   CLI usage:
-   python ultralytics/yolo/v8/detect/train.py model=yolov8n.yaml data=coco128 epochs=100 imgsz=640
-   TODO:
-   yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=100
-   """
    train()

@ -234,6 +234,7 @@ class DetectionValidator(BaseValidator):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def val(cfg):
+   cfg.model = cfg.model or "yolov8n.pt"
    cfg.data = cfg.data or "coco128.yaml"
    validator = DetectionValidator(args=cfg)
    validator(model=cfg.model)

@ -143,7 +143,7 @@ class SegLoss(Loss):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
-   cfg.model = cfg.model or "yolov8n-seg.yaml"
+   cfg.model = cfg.model or "yolov8n-seg.pt"
    cfg.data = cfg.data or "coco128-seg.yaml"  # or yolo.ClassificationDataset("mnist")
    cfg.device = cfg.device if cfg.device is not None else ''
    # trainer = SegmentationTrainer(cfg)
@ -154,11 +154,4 @@ def train(cfg):
if __name__ == "__main__":
-   """
-   CLI usage:
-   python ultralytics/yolo/v8/segment/train.py model=yolov8n-seg.yaml data=coco128-segments epochs=100 imgsz=640
-   TODO:
-   Direct cli support, i.e, yolov8 classify_train args.epochs 10
-   """
    train()

@ -114,8 +114,9 @@ class SegmentationValidator(DetectionValidator):
    masks=True)
if self.args.plots:
    self.confusion_matrix.process_batch(predn, labelsn)
-   self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:,
-                      5], cls.squeeze(-1)))  # conf, pcls, tcls
+   # Append correct_masks, correct_boxes, pconf, pcls, tcls
+   self.stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], cls.squeeze(-1)))
pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
if self.args.plots and self.batch_i < 3:
