`ultralytics 8.0.81` single-line docstring updates (#2061)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 2 years ago committed by GitHub
parent 5bce1c3021
commit a38f227672
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,6 +42,7 @@ the benchmarks to their specific needs and compare the performance of different
| `model` | `None` | path to model file, i.e. yolov8n.pt, yolov8n.yaml | | `model` | `None` | path to model file, i.e. yolov8n.pt, yolov8n.yaml |
| `imgsz` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) | | `imgsz` | `640` | image size as scalar or (h, w) list, i.e. (640, 480) |
| `half` | `False` | FP16 quantization | | `half` | `False` | FP16 quantization |
| `int8` | `False` | INT8 quantization |
| `device` | `None` | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu | | `device` | `None` | device to run on, i.e. cuda device=0 or device=0,1,2,3 or device=cpu |
| `hard_fail` | `False` | do not continue on error (bool), or val floor threshold (float) | | `hard_fail` | `False` | do not continue on error (bool), or val floor threshold (float) |

@ -1,11 +1,11 @@
# iOSDetectModel # Exporter
--- ---
:::ultralytics.yolo.engine.exporter.iOSDetectModel :::ultralytics.yolo.engine.exporter.Exporter
<br><br> <br><br>
# Exporter # iOSDetectModel
--- ---
:::ultralytics.yolo.engine.exporter.Exporter :::ultralytics.yolo.engine.exporter.iOSDetectModel
<br><br> <br><br>
# export_formats # export_formats

@ -1,6 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
site_name: YOLOv8 Docs site_name: Ultralytics YOLOv8 Docs
site_url: https://docs.ultralytics.com
repo_url: https://github.com/ultralytics/ultralytics repo_url: https://github.com/ultralytics/ultralytics
edit_uri: https://github.com/ultralytics/ultralytics/tree/main/docs edit_uri: https://github.com/ultralytics/ultralytics/tree/main/docs
repo_name: ultralytics/ultralytics repo_name: ultralytics/ultralytics

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.80' __version__ = '8.0.81'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.engine.model import YOLO

@ -130,6 +130,7 @@ def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbos
@TryExcept(verbose=verbose) @TryExcept(verbose=verbose)
def func(func_method, func_url, **func_kwargs): def func(func_method, func_url, **func_kwargs):
"""Make HTTP requests with retries and timeouts, with optional progress tracking."""
r = None # response r = None # response
t0 = time.time() # initial time for timer t0 = time.time() # initial time for timer
for i in range(retry + 1): for i in range(retry + 1):

@ -202,6 +202,7 @@ class AutoBackend(nn.Module):
from ultralytics.yolo.engine.exporter import gd_outputs from ultralytics.yolo.engine.exporter import gd_outputs
def wrap_frozen_graph(gd, inputs, outputs): def wrap_frozen_graph(gd, inputs, outputs):
"""Wrap frozen graphs for deployment."""
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
ge = x.graph.as_graph_element ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs)) return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
@ -427,6 +428,7 @@ class AutoBackend(nn.Module):
@staticmethod @staticmethod
def _apply_default_class_names(data): def _apply_default_class_names(data):
"""Applies default class names to an input YAML file or returns numerical class names."""
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))['names'] return yaml_load(check_yaml(data))['names']
return {i: f'class{i}' for i in range(999)} # return default if above errors return {i: f'class{i}' for i in range(999)} # return default if above errors

@ -34,6 +34,7 @@ class AutoShape(nn.Module):
amp = False # Automatic Mixed Precision (AMP) inference amp = False # Automatic Mixed Precision (AMP) inference
def __init__(self, model, verbose=True): def __init__(self, model, verbose=True):
"""Initializes object and copies attributes from model object."""
super().__init__() super().__init__()
if verbose: if verbose:
LOGGER.info('Adding AutoShape... ') LOGGER.info('Adding AutoShape... ')
@ -125,6 +126,7 @@ class AutoShape(nn.Module):
class Detections: class Detections:
# YOLOv8 detections class for inference results # YOLOv8 detections class for inference results
def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None): def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
"""Initialize object attributes for YOLO detection results."""
super().__init__() super().__init__()
d = pred[0].device # device d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
@ -142,6 +144,7 @@ class Detections:
self.s = tuple(shape) # inference BCHW shape self.s = tuple(shape) # inference BCHW shape
def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')): def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
"""Return performance metrics and optionally cropped/save images or results."""
s, crops = '', [] s, crops = '', []
for i, (im, pred) in enumerate(zip(self.ims, self.pred)): for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
@ -187,17 +190,21 @@ class Detections:
return crops return crops
def show(self, labels=True): def show(self, labels=True):
"""Displays YOLO results with detected bounding boxes."""
self._run(show=True, labels=labels) # show results self._run(show=True, labels=labels) # show results
def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False): def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
"""Save detection results with optional labels to specified directory."""
save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
self._run(save=True, labels=labels, save_dir=save_dir) # save results self._run(save=True, labels=labels, save_dir=save_dir) # save results
def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False): def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
"""Crops images into detections and saves them if 'save' is True."""
save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
return self._run(crop=True, save=save, save_dir=save_dir) # crop results return self._run(crop=True, save=save, save_dir=save_dir) # crop results
def render(self, labels=True): def render(self, labels=True):
"""Renders detected objects and returns images."""
self._run(render=True, labels=labels) # render results self._run(render=True, labels=labels) # render results
return self.ims return self.ims
@ -222,6 +229,7 @@ class Detections:
return x return x
def print(self): def print(self):
"""Print the results of the `self._run()` function."""
LOGGER.info(self.__str__()) LOGGER.info(self.__str__())
def __len__(self): # override len(results) def __len__(self): # override len(results)
@ -231,4 +239,5 @@ class Detections:
return self._run(pprint=True) # print results return self._run(pprint=True) # print results
def __repr__(self): def __repr__(self):
"""Returns a printable representation of the object."""
return f'YOLOv8 {self.__class__} instance\n' + self.__str__() return f'YOLOv8 {self.__class__} instance\n' + self.__str__()

@ -25,15 +25,18 @@ class Conv(nn.Module):
default_act = nn.SiLU() # default activation default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__() super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False) self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2) self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x): def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x))) return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x): def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x)) return self.act(self.conv(x))
@ -56,15 +59,18 @@ class ConvTranspose(nn.Module):
default_act = nn.SiLU() # default activation default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True): def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
"""Initialize ConvTranspose2d layer with batch normalization and activation function."""
super().__init__() super().__init__()
self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn) self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity() self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x): def forward(self, x):
"""Applies transposed convolutions, batch normalization and activation to input."""
return self.act(self.bn(self.conv_transpose(x))) return self.act(self.bn(self.conv_transpose(x)))
def forward_fuse(self, x): def forward_fuse(self, x):
"""Applies activation and convolution transpose operation to input."""
return self.act(self.conv_transpose(x)) return self.act(self.conv_transpose(x))
@ -75,6 +81,7 @@ class DFL(nn.Module):
""" """
def __init__(self, c1=16): def __init__(self, c1=16):
"""Initialize a convolutional layer with a given number of input channels."""
super().__init__() super().__init__()
self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False) self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
x = torch.arange(c1, dtype=torch.float) x = torch.arange(c1, dtype=torch.float)
@ -82,6 +89,7 @@ class DFL(nn.Module):
self.c1 = c1 self.c1 = c1
def forward(self, x): def forward(self, x):
"""Applies a transformer layer on input tensor 'x' and returns a tensor."""
b, c, a = x.shape # batch, channels, anchors b, c, a = x.shape # batch, channels, anchors
return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a) return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
# return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a) # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
@ -91,6 +99,7 @@ class TransformerLayer(nn.Module):
"""Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance).""" """Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)."""
def __init__(self, c, num_heads): def __init__(self, c, num_heads):
"""Initializes a self-attention mechanism using linear transformations and multi-head attention."""
super().__init__() super().__init__()
self.q = nn.Linear(c, c, bias=False) self.q = nn.Linear(c, c, bias=False)
self.k = nn.Linear(c, c, bias=False) self.k = nn.Linear(c, c, bias=False)
@ -100,6 +109,7 @@ class TransformerLayer(nn.Module):
self.fc2 = nn.Linear(c, c, bias=False) self.fc2 = nn.Linear(c, c, bias=False)
def forward(self, x): def forward(self, x):
"""Apply a transformer block to the input x and return the output."""
x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
x = self.fc2(self.fc1(x)) + x x = self.fc2(self.fc1(x)) + x
return x return x
@ -109,6 +119,7 @@ class TransformerBlock(nn.Module):
"""Vision Transformer https://arxiv.org/abs/2010.11929.""" """Vision Transformer https://arxiv.org/abs/2010.11929."""
def __init__(self, c1, c2, num_heads, num_layers): def __init__(self, c1, c2, num_heads, num_layers):
"""Initialize a Transformer module with position embedding and specified number of heads and layers."""
super().__init__() super().__init__()
self.conv = None self.conv = None
if c1 != c2: if c1 != c2:
@ -118,6 +129,7 @@ class TransformerBlock(nn.Module):
self.c2 = c2 self.c2 = c2
def forward(self, x): def forward(self, x):
"""Forward propagates the input through the bottleneck module."""
if self.conv is not None: if self.conv is not None:
x = self.conv(x) x = self.conv(x)
b, _, w, h = x.shape b, _, w, h = x.shape
@ -136,6 +148,7 @@ class Bottleneck(nn.Module):
self.add = shortcut and c1 == c2 self.add = shortcut and c1 == c2
def forward(self, x): def forward(self, x):
"""'forward()' applies the YOLOv5 FPN to input data."""
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
@ -154,6 +167,7 @@ class BottleneckCSP(nn.Module):
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x): def forward(self, x):
"""Applies a CSP bottleneck with 3 convolutions."""
y1 = self.cv3(self.m(self.cv1(x))) y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x) y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1)))) return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
@ -171,6 +185,7 @@ class C3(nn.Module):
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))
def forward(self, x): def forward(self, x):
"""Forward pass through the CSP bottleneck with 2 convolutions."""
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
@ -186,6 +201,7 @@ class C2(nn.Module):
self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))
def forward(self, x): def forward(self, x):
"""Forward pass through the CSP bottleneck with 2 convolutions."""
a, b = self.cv1(x).chunk(2, 1) a, b = self.cv1(x).chunk(2, 1)
return self.cv2(torch.cat((self.m(a), b), 1)) return self.cv2(torch.cat((self.m(a), b), 1))
@ -201,11 +217,13 @@ class C2f(nn.Module):
self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)) self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))
def forward(self, x): def forward(self, x):
"""Forward pass of a YOLOv5 CSPDarknet backbone layer."""
y = list(self.cv1(x).chunk(2, 1)) y = list(self.cv1(x).chunk(2, 1))
y.extend(m(y[-1]) for m in self.m) y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1)) return self.cv2(torch.cat(y, 1))
def forward_split(self, x): def forward_split(self, x):
"""Applies spatial attention to module's input."""
y = list(self.cv1(x).split((self.c, self.c), 1)) y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in self.m) y.extend(m(y[-1]) for m in self.m)
return self.cv2(torch.cat(y, 1)) return self.cv2(torch.cat(y, 1))
@ -228,6 +246,7 @@ class SpatialAttention(nn.Module):
"""Spatial-attention module.""" """Spatial-attention module."""
def __init__(self, kernel_size=7): def __init__(self, kernel_size=7):
"""Initialize Spatial-attention module with kernel size argument."""
super().__init__() super().__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7' assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1 padding = 3 if kernel_size == 7 else 1
@ -235,6 +254,7 @@ class SpatialAttention(nn.Module):
self.act = nn.Sigmoid() self.act = nn.Sigmoid()
def forward(self, x): def forward(self, x):
"""Apply channel and spatial attention on input for feature recalibration."""
return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1))) return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))
@ -247,6 +267,7 @@ class CBAM(nn.Module):
self.spatial_attention = SpatialAttention(kernel_size) self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x): def forward(self, x):
"""Applies the forward pass through C1 module."""
return self.spatial_attention(self.channel_attention(x)) return self.spatial_attention(self.channel_attention(x))
@ -259,6 +280,7 @@ class C1(nn.Module):
self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n))) self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))
def forward(self, x): def forward(self, x):
"""Applies cross-convolutions to input in the C3 module."""
y = self.cv1(x) y = self.cv1(x)
return self.m(y) + y return self.m(y) + y
@ -267,6 +289,7 @@ class C3x(C3):
"""C3 module with cross-convolutions.""" """C3 module with cross-convolutions."""
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
"""Initialize C3TR instance and set default parameters."""
super().__init__(c1, c2, n, shortcut, g, e) super().__init__(c1, c2, n, shortcut, g, e)
self.c_ = int(c2 * e) self.c_ = int(c2 * e)
self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n))) self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))
@ -276,6 +299,7 @@ class C3TR(C3):
"""C3 module with TransformerBlock().""" """C3 module with TransformerBlock()."""
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
"""Initialize C3Ghost module with GhostBottleneck()."""
super().__init__(c1, c2, n, shortcut, g, e) super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e) c_ = int(c2 * e)
self.m = TransformerBlock(c_, c_, 4, n) self.m = TransformerBlock(c_, c_, 4, n)
@ -285,6 +309,7 @@ class C3Ghost(C3):
"""C3 module with GhostBottleneck().""" """C3 module with GhostBottleneck()."""
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
"""Initialize 'SPP' module with various pooling sizes for spatial pyramid pooling."""
super().__init__(c1, c2, n, shortcut, g, e) super().__init__(c1, c2, n, shortcut, g, e)
c_ = int(c2 * e) # hidden channels c_ = int(c2 * e) # hidden channels
self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n))) self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
@ -294,6 +319,7 @@ class SPP(nn.Module):
"""Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729.""" """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""
def __init__(self, c1, c2, k=(5, 9, 13)): def __init__(self, c1, c2, k=(5, 9, 13)):
"""Initialize the SPP layer with input/output channels and pooling kernel sizes."""
super().__init__() super().__init__()
c_ = c1 // 2 # hidden channels c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1) self.cv1 = Conv(c1, c_, 1, 1)
@ -301,6 +327,7 @@ class SPP(nn.Module):
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
def forward(self, x): def forward(self, x):
"""Forward pass of the SPP layer, performing spatial pyramid pooling."""
x = self.cv1(x) x = self.cv1(x)
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
@ -316,6 +343,7 @@ class SPPF(nn.Module):
self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
def forward(self, x): def forward(self, x):
"""Forward pass through Ghost Convolution block."""
x = self.cv1(x) x = self.cv1(x)
y1 = self.m(x) y1 = self.m(x)
y2 = self.m(y1) y2 = self.m(y1)
@ -345,6 +373,7 @@ class GhostConv(nn.Module):
self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act) self.cv2 = Conv(c_, c_, 5, 1, None, c_, act=act)
def forward(self, x): def forward(self, x):
"""Forward propagation through a Ghost Bottleneck layer with skip connection."""
y = self.cv1(x) y = self.cv1(x)
return torch.cat((y, self.cv2(y)), 1) return torch.cat((y, self.cv2(y)), 1)
@ -363,6 +392,7 @@ class GhostBottleneck(nn.Module):
act=False)) if s == 2 else nn.Identity() act=False)) if s == 2 else nn.Identity()
def forward(self, x): def forward(self, x):
"""Applies skip connection and concatenation to input tensor."""
return self.conv(x) + self.shortcut(x) return self.conv(x) + self.shortcut(x)
@ -370,10 +400,12 @@ class Concat(nn.Module):
"""Concatenate a list of tensors along dimension.""" """Concatenate a list of tensors along dimension."""
def __init__(self, dimension=1): def __init__(self, dimension=1):
"""Concatenates a list of tensors along a specified dimension."""
super().__init__() super().__init__()
self.d = dimension self.d = dimension
def forward(self, x): def forward(self, x):
"""Forward pass for the YOLOv8 mask Proto module."""
return torch.cat(x, self.d) return torch.cat(x, self.d)
@ -388,6 +420,7 @@ class Proto(nn.Module):
self.cv3 = Conv(c_, c2) self.cv3 = Conv(c_, c2)
def forward(self, x): def forward(self, x):
"""Performs a forward pass through layers using an upsampled input image."""
return self.cv3(self.cv2(self.upsample(self.cv1(x)))) return self.cv3(self.cv2(self.upsample(self.cv1(x))))
@ -395,9 +428,11 @@ class Ensemble(nn.ModuleList):
"""Ensemble of models.""" """Ensemble of models."""
def __init__(self): def __init__(self):
"""Initialize an ensemble of models."""
super().__init__() super().__init__()
def forward(self, x, augment=False, profile=False, visualize=False): def forward(self, x, augment=False, profile=False, visualize=False):
"""Function generates the YOLOv5 network's final layer."""
y = [module(x, augment, profile, visualize)[0] for module in self] y = [module(x, augment, profile, visualize)[0] for module in self]
# y = torch.stack(y).max(0)[0] # max ensemble # y = torch.stack(y).max(0)[0] # max ensemble
# y = torch.stack(y).mean(0) # mean ensemble # y = torch.stack(y).mean(0) # mean ensemble
@ -430,6 +465,7 @@ class Detect(nn.Module):
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
def forward(self, x): def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
shape = x[0].shape # BCHW shape = x[0].shape # BCHW
for i in range(self.nl): for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
@ -463,6 +499,7 @@ class Segment(Detect):
"""YOLOv8 Segment head for segmentation models.""" """YOLOv8 Segment head for segmentation models."""
def __init__(self, nc=80, nm=32, npr=256, ch=()): def __init__(self, nc=80, nm=32, npr=256, ch=()):
"""Initialize the YOLO model attributes such as the number of masks, prototypes, and the convolution layers."""
super().__init__(nc, ch) super().__init__(nc, ch)
self.nm = nm # number of masks self.nm = nm # number of masks
self.npr = npr # number of protos self.npr = npr # number of protos
@ -473,6 +510,7 @@ class Segment(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
def forward(self, x): def forward(self, x):
"""Return model outputs and mask coefficients if training, otherwise return outputs and mask coefficients."""
p = self.proto(x[0]) # mask protos p = self.proto(x[0]) # mask protos
bs = p.shape[0] # batch size bs = p.shape[0] # batch size
@ -487,6 +525,7 @@ class Pose(Detect):
"""YOLOv8 Pose head for keypoints models.""" """YOLOv8 Pose head for keypoints models."""
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
"""Initialize YOLO network with default parameters and Convolutional Layers."""
super().__init__(nc, ch) super().__init__(nc, ch)
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible) self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
@ -496,6 +535,7 @@ class Pose(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch) self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
def forward(self, x): def forward(self, x):
"""Perform forward pass through YOLO model and return predictions."""
bs = x[0].shape[0] # batch size bs = x[0].shape[0] # batch size
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
x = self.detect(self, x) x = self.detect(self, x)
@ -505,6 +545,7 @@ class Pose(Detect):
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt)) return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
def kpts_decode(self, kpts): def kpts_decode(self, kpts):
"""Decodes keypoints."""
ndim = self.kpt_shape[1] ndim = self.kpt_shape[1]
y = kpts.clone() y = kpts.clone()
if ndim == 3: if ndim == 3:
@ -526,6 +567,7 @@ class Classify(nn.Module):
self.linear = nn.Linear(c_, c2) # to x(b,c2) self.linear = nn.Linear(c_, c2) # to x(b,c2)
def forward(self, x): def forward(self, x):
"""Performs a forward pass of the YOLO model on input image data."""
if isinstance(x, list): if isinstance(x, list):
x = torch.cat(x, 1) x = torch.cat(x, 1)
x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1))) x = self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))

@ -199,11 +199,13 @@ class DetectionModel(BaseModel):
LOGGER.info('') LOGGER.info('')
def forward(self, x, augment=False, profile=False, visualize=False): def forward(self, x, augment=False, profile=False, visualize=False):
"""Run forward pass on input image(s) with optional augmentation and profiling."""
if augment: if augment:
return self._forward_augment(x) # augmented inference, None return self._forward_augment(x) # augmented inference, None
return self._forward_once(x, profile, visualize) # single-scale inference, train return self._forward_once(x, profile, visualize) # single-scale inference, train
def _forward_augment(self, x): def _forward_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
img_size = x.shape[-2:] # height, width img_size = x.shape[-2:] # height, width
s = [1, 0.83, 0.67] # scales s = [1, 0.83, 0.67] # scales
f = [None, 3, None] # flips (2-ud, 3-lr) f = [None, 3, None] # flips (2-ud, 3-lr)
@ -244,9 +246,11 @@ class SegmentationModel(DetectionModel):
"""YOLOv8 segmentation model.""" """YOLOv8 segmentation model."""
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True): def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
"""Initialize YOLOv8 segmentation model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def _forward_augment(self, x): def _forward_augment(self, x):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')) raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
@ -254,6 +258,7 @@ class PoseModel(DetectionModel):
"""YOLOv8 pose model.""" """YOLOv8 pose model."""
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True): def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
"""Initialize YOLOv8 Pose model."""
if not isinstance(cfg, dict): if not isinstance(cfg, dict):
cfg = yaml_model_load(cfg) # load model YAML cfg = yaml_model_load(cfg) # load model YAML
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']): if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
@ -292,6 +297,7 @@ class ClassificationModel(BaseModel):
self.nc = nc self.nc = nc
def _from_yaml(self, cfg, ch, nc, verbose): def _from_yaml(self, cfg, ch, nc, verbose):
"""Set YOLOv8 model configurations and define the model architecture."""
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
# Define model # Define model
@ -501,6 +507,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
def yaml_model_load(path): def yaml_model_load(path):
"""Load a YOLOv8 model from a YAML file."""
import re import re
path = Path(path) path = Path(path)

@ -37,6 +37,7 @@ def on_predict_start(predictor, persist=False):
def on_predict_postprocess_end(predictor): def on_predict_postprocess_end(predictor):
"""Postprocess detected boxes and update with object tracking."""
bs = predictor.dataset.bs bs = predictor.dataset.bs
im0s = predictor.batch[2] im0s = predictor.batch[2]
im0s = im0s if isinstance(im0s, list) else [im0s] im0s = im0s if isinstance(im0s, list) else [im0s]

@ -6,6 +6,8 @@ import numpy as np
class TrackState: class TrackState:
"""Enumeration of possible object tracking states."""
New = 0 New = 0
Tracked = 1 Tracked = 1
Lost = 2 Lost = 2
@ -13,6 +15,8 @@ class TrackState:
class BaseTrack: class BaseTrack:
"""Base class for object tracking, handling basic track attributes and operations."""
_count = 0 _count = 0
track_id = 0 track_id = 0
@ -32,28 +36,36 @@ class BaseTrack:
@property @property
def end_frame(self): def end_frame(self):
"""Return the last frame ID of the track."""
return self.frame_id return self.frame_id
@staticmethod @staticmethod
def next_id(): def next_id():
"""Increment and return the global track ID counter."""
BaseTrack._count += 1 BaseTrack._count += 1
return BaseTrack._count return BaseTrack._count
def activate(self, *args): def activate(self, *args):
"""Activate the track with the provided arguments."""
raise NotImplementedError raise NotImplementedError
def predict(self): def predict(self):
"""Predict the next state of the track."""
raise NotImplementedError raise NotImplementedError
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
"""Update the track with new observations."""
raise NotImplementedError raise NotImplementedError
def mark_lost(self): def mark_lost(self):
"""Mark the track as lost."""
self.state = TrackState.Lost self.state = TrackState.Lost
def mark_removed(self): def mark_removed(self):
"""Mark the track as removed."""
self.state = TrackState.Removed self.state = TrackState.Removed
@staticmethod @staticmethod
def reset_id(): def reset_id():
"""Reset the global track ID counter."""
BaseTrack._count = 0 BaseTrack._count = 0

@ -15,6 +15,7 @@ class BOTrack(STrack):
shared_kalman = KalmanFilterXYWH() shared_kalman = KalmanFilterXYWH()
def __init__(self, tlwh, score, cls, feat=None, feat_history=50): def __init__(self, tlwh, score, cls, feat=None, feat_history=50):
"""Initialize YOLOv8 object with temporal parameters, such as feature history, alpha and current features."""
super().__init__(tlwh, score, cls) super().__init__(tlwh, score, cls)
self.smooth_feat = None self.smooth_feat = None
@ -25,6 +26,7 @@ class BOTrack(STrack):
self.alpha = 0.9 self.alpha = 0.9
def update_features(self, feat): def update_features(self, feat):
"""Update features vector and smooth it using exponential moving average."""
feat /= np.linalg.norm(feat) feat /= np.linalg.norm(feat)
self.curr_feat = feat self.curr_feat = feat
if self.smooth_feat is None: if self.smooth_feat is None:
@ -35,6 +37,7 @@ class BOTrack(STrack):
self.smooth_feat /= np.linalg.norm(self.smooth_feat) self.smooth_feat /= np.linalg.norm(self.smooth_feat)
def predict(self): def predict(self):
"""Predicts the mean and covariance using Kalman filter."""
mean_state = self.mean.copy() mean_state = self.mean.copy()
if self.state != TrackState.Tracked: if self.state != TrackState.Tracked:
mean_state[6] = 0 mean_state[6] = 0
@ -43,11 +46,13 @@ class BOTrack(STrack):
self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a track with updated features and optionally assigns a new ID."""
if new_track.curr_feat is not None: if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
super().re_activate(new_track, frame_id, new_id) super().re_activate(new_track, frame_id, new_id)
def update(self, new_track, frame_id): def update(self, new_track, frame_id):
"""Update the YOLOv8 instance with new track and frame ID."""
if new_track.curr_feat is not None: if new_track.curr_feat is not None:
self.update_features(new_track.curr_feat) self.update_features(new_track.curr_feat)
super().update(new_track, frame_id) super().update(new_track, frame_id)
@ -65,6 +70,7 @@ class BOTrack(STrack):
@staticmethod @staticmethod
def multi_predict(stracks): def multi_predict(stracks):
"""Predicts the mean and covariance of multiple object tracks using shared Kalman filter."""
if len(stracks) <= 0: if len(stracks) <= 0:
return return
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -79,6 +85,7 @@ class BOTrack(STrack):
stracks[i].covariance = cov stracks[i].covariance = cov
def convert_coords(self, tlwh): def convert_coords(self, tlwh):
"""Converts Top-Left-Width-Height bounding box coordinates to X-Y-Width-Height format."""
return self.tlwh_to_xywh(tlwh) return self.tlwh_to_xywh(tlwh)
@staticmethod @staticmethod
@ -94,6 +101,7 @@ class BOTrack(STrack):
class BOTSORT(BYTETracker): class BOTSORT(BYTETracker):
def __init__(self, args, frame_rate=30): def __init__(self, args, frame_rate=30):
"""Initialize YOLOv8 object with ReID module and GMC algorithm."""
super().__init__(args, frame_rate) super().__init__(args, frame_rate)
# ReID module # ReID module
self.proximity_thresh = args.proximity_thresh self.proximity_thresh = args.proximity_thresh
@ -106,9 +114,11 @@ class BOTSORT(BYTETracker):
self.gmc = GMC(method=args.cmc_method) self.gmc = GMC(method=args.cmc_method)
def get_kalmanfilter(self): def get_kalmanfilter(self):
"""Returns an instance of KalmanFilterXYWH for object tracking."""
return KalmanFilterXYWH() return KalmanFilterXYWH()
def init_track(self, dets, scores, cls, img=None): def init_track(self, dets, scores, cls, img=None):
"""Initialize track with detections, scores, and classes."""
if len(dets) == 0: if len(dets) == 0:
return [] return []
if self.args.with_reid and self.encoder is not None: if self.args.with_reid and self.encoder is not None:
@ -118,6 +128,7 @@ class BOTSORT(BYTETracker):
return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections return [BOTrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] # detections
def get_dists(self, tracks, detections): def get_dists(self, tracks, detections):
"""Get distances between tracks and detections using IoU and (optionally) ReID embeddings."""
dists = matching.iou_distance(tracks, detections) dists = matching.iou_distance(tracks, detections)
dists_mask = (dists > self.proximity_thresh) dists_mask = (dists > self.proximity_thresh)
@ -133,4 +144,5 @@ class BOTSORT(BYTETracker):
return dists return dists
def multi_predict(self, tracks): def multi_predict(self, tracks):
"""Predict and track multiple objects with YOLOv8 model."""
BOTrack.multi_predict(tracks) BOTrack.multi_predict(tracks)

@ -23,6 +23,7 @@ class STrack(BaseTrack):
self.idx = tlwh[-1] self.idx = tlwh[-1]
def predict(self): def predict(self):
"""Predicts mean and covariance using Kalman filter."""
mean_state = self.mean.copy() mean_state = self.mean.copy()
if self.state != TrackState.Tracked: if self.state != TrackState.Tracked:
mean_state[7] = 0 mean_state[7] = 0
@ -30,6 +31,7 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def multi_predict(stracks): def multi_predict(stracks):
"""Perform multi-object predictive tracking using Kalman filter for given stracks."""
if len(stracks) <= 0: if len(stracks) <= 0:
return return
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
@ -44,6 +46,7 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def multi_gmc(stracks, H=np.eye(2, 3)): def multi_gmc(stracks, H=np.eye(2, 3)):
"""Update state tracks positions and covariances using a homography matrix."""
if len(stracks) > 0: if len(stracks) > 0:
multi_mean = np.asarray([st.mean.copy() for st in stracks]) multi_mean = np.asarray([st.mean.copy() for st in stracks])
multi_covariance = np.asarray([st.covariance for st in stracks]) multi_covariance = np.asarray([st.covariance for st in stracks])
@ -74,6 +77,7 @@ class STrack(BaseTrack):
self.start_frame = frame_id self.start_frame = frame_id
def re_activate(self, new_track, frame_id, new_id=False): def re_activate(self, new_track, frame_id, new_id=False):
"""Reactivates a previously lost track with a new detection."""
self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance, self.mean, self.covariance = self.kalman_filter.update(self.mean, self.covariance,
self.convert_coords(new_track.tlwh)) self.convert_coords(new_track.tlwh))
self.tracklet_len = 0 self.tracklet_len = 0
@ -107,6 +111,7 @@ class STrack(BaseTrack):
self.idx = new_track.idx self.idx = new_track.idx
def convert_coords(self, tlwh): def convert_coords(self, tlwh):
"""Convert a bounding box's top-left-width-height format to its x-y-angle-height equivalent."""
return self.tlwh_to_xyah(tlwh) return self.tlwh_to_xyah(tlwh)
@property @property
@ -142,23 +147,27 @@ class STrack(BaseTrack):
@staticmethod @staticmethod
def tlbr_to_tlwh(tlbr): def tlbr_to_tlwh(tlbr):
"""Converts top-left bottom-right format to top-left width height format."""
ret = np.asarray(tlbr).copy() ret = np.asarray(tlbr).copy()
ret[2:] -= ret[:2] ret[2:] -= ret[:2]
return ret return ret
@staticmethod @staticmethod
def tlwh_to_tlbr(tlwh): def tlwh_to_tlbr(tlwh):
"""Converts tlwh bounding box format to tlbr format."""
ret = np.asarray(tlwh).copy() ret = np.asarray(tlwh).copy()
ret[2:] += ret[:2] ret[2:] += ret[:2]
return ret return ret
def __repr__(self): def __repr__(self):
"""Return a string representation of the BYTETracker object with start and end frames and track ID."""
return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})' return f'OT_{self.track_id}_({self.start_frame}-{self.end_frame})'
class BYTETracker: class BYTETracker:
def __init__(self, args, frame_rate=30): def __init__(self, args, frame_rate=30):
"""Initialize a YOLOv8 object to track objects with given arguments and frame rate."""
self.tracked_stracks = [] # type: list[STrack] self.tracked_stracks = [] # type: list[STrack]
self.lost_stracks = [] # type: list[STrack] self.lost_stracks = [] # type: list[STrack]
self.removed_stracks = [] # type: list[STrack] self.removed_stracks = [] # type: list[STrack]
@ -170,6 +179,7 @@ class BYTETracker:
self.reset_id() self.reset_id()
def update(self, results, img=None): def update(self, results, img=None):
"""Updates object tracker with new detections and returns tracked object bounding boxes."""
self.frame_id += 1 self.frame_id += 1
activated_starcks = [] activated_starcks = []
refind_stracks = [] refind_stracks = []
@ -285,12 +295,15 @@ class BYTETracker:
dtype=np.float32) dtype=np.float32)
def get_kalmanfilter(self): def get_kalmanfilter(self):
"""Returns a Kalman filter object for tracking bounding boxes."""
return KalmanFilterXYAH() return KalmanFilterXYAH()
def init_track(self, dets, scores, cls, img=None): def init_track(self, dets, scores, cls, img=None):
"""Initialize object tracking with detections and scores using STrack algorithm."""
return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections return [STrack(xyxy, s, c) for (xyxy, s, c) in zip(dets, scores, cls)] if len(dets) else [] # detections
def get_dists(self, tracks, detections): def get_dists(self, tracks, detections):
"""Calculates the distance between tracks and detections using IOU and fuses scores."""
dists = matching.iou_distance(tracks, detections) dists = matching.iou_distance(tracks, detections)
# TODO: mot20 # TODO: mot20
# if not self.args.mot20: # if not self.args.mot20:
@ -298,13 +311,16 @@ class BYTETracker:
return dists return dists
def multi_predict(self, tracks): def multi_predict(self, tracks):
"""Returns the predicted tracks using the YOLOv8 network."""
STrack.multi_predict(tracks) STrack.multi_predict(tracks)
def reset_id(self): def reset_id(self):
"""Resets the ID counter of STrack."""
STrack.reset_id() STrack.reset_id()
@staticmethod @staticmethod
def joint_stracks(tlista, tlistb): def joint_stracks(tlista, tlistb):
"""Combine two lists of stracks into a single one."""
exists = {} exists = {}
res = [] res = []
for t in tlista: for t in tlista:
@ -332,6 +348,7 @@ class BYTETracker:
@staticmethod @staticmethod
def remove_duplicate_stracks(stracksa, stracksb): def remove_duplicate_stracks(stracksa, stracksb):
"""Remove duplicate stracks with non-maximum IOU distance."""
pdist = matching.iou_distance(stracksa, stracksb) pdist = matching.iou_distance(stracksa, stracksb)
pairs = np.where(pdist < 0.15) pairs = np.where(pdist < 0.15)
dupa, dupb = [], [] dupa, dupb = [], []

@ -11,6 +11,7 @@ from ultralytics.yolo.utils import LOGGER
class GMC: class GMC:
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None): def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
"""Initialize a video tracker with specified parameters."""
super().__init__() super().__init__()
self.method = method self.method = method
@ -69,6 +70,7 @@ class GMC:
self.initializedFirstFrame = False self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None): def apply(self, raw_frame, detections=None):
"""Apply object detection on a raw frame using specified method."""
if self.method in ['orb', 'sift']: if self.method in ['orb', 'sift']:
return self.applyFeatures(raw_frame, detections) return self.applyFeatures(raw_frame, detections)
elif self.method == 'ecc': elif self.method == 'ecc':
@ -303,6 +305,7 @@ class GMC:
return H return H
def applyFile(self, raw_frame, detections=None): def applyFile(self, raw_frame, detections=None):
"""Return the homography matrix based on the GCPs in the next line of the input GMC file."""
line = self.gmcFile.readline() line = self.gmcFile.readline()
tokens = line.split('\t') tokens = line.split('\t')
H = np.eye(2, 3, dtype=np.float_) H = np.eye(2, 3, dtype=np.float_)

@ -27,6 +27,7 @@ class KalmanFilterXYAH:
""" """
def __init__(self): def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainty weights."""
ndim, dt = 4, 1. ndim, dt = 4, 1.
# Create Kalman filter model matrices. # Create Kalman filter model matrices.
@ -253,6 +254,7 @@ class KalmanFilterXYWH:
""" """
def __init__(self): def __init__(self):
"""Initialize Kalman filter model matrices with motion and observation uncertainties."""
ndim, dt = 4, 1. ndim, dt = 4, 1.
# Create Kalman filter model matrices. # Create Kalman filter model matrices.

@ -18,6 +18,7 @@ except (ImportError, AssertionError, AttributeError):
def merge_matches(m1, m2, shape): def merge_matches(m1, m2, shape):
"""Merge two sets of matches and return matched and unmatched indices."""
O, P, Q = shape O, P, Q = shape
m1 = np.asarray(m1) m1 = np.asarray(m1)
m2 = np.asarray(m2) m2 = np.asarray(m2)
@ -35,6 +36,7 @@ def merge_matches(m1, m2, shape):
def _indices_to_matches(cost_matrix, indices, thresh): def _indices_to_matches(cost_matrix, indices, thresh):
"""_indices_to_matches: Return matched and unmatched indices given a cost matrix, indices, and a threshold."""
matched_cost = cost_matrix[tuple(zip(*indices))] matched_cost = cost_matrix[tuple(zip(*indices))]
matched_mask = (matched_cost <= thresh) matched_mask = (matched_cost <= thresh)
@ -144,6 +146,7 @@ def embedding_distance(tracks, detections, metric='cosine'):
def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False): def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
"""Apply gating to the cost matrix based on predicted tracks and detected objects."""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
gating_dim = 2 if only_position else 4 gating_dim = 2 if only_position else 4
@ -156,6 +159,7 @@ def gate_cost_matrix(kf, cost_matrix, tracks, detections, only_position=False):
def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98): def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda_=0.98):
"""Fuse motion between tracks and detections with gating and Kalman filtering."""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
gating_dim = 2 if only_position else 4 gating_dim = 2 if only_position else 4
@ -169,6 +173,7 @@ def fuse_motion(kf, cost_matrix, tracks, detections, only_position=False, lambda
def fuse_iou(cost_matrix, tracks, detections): def fuse_iou(cost_matrix, tracks, detections):
"""Fuses ReID and IoU similarity matrices to yield a cost matrix for object tracking."""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
reid_sim = 1 - cost_matrix reid_sim = 1 - cost_matrix
@ -181,6 +186,7 @@ def fuse_iou(cost_matrix, tracks, detections):
def fuse_score(cost_matrix, detections): def fuse_score(cost_matrix, detections):
"""Fuses cost matrix with detection scores to produce a single similarity matrix."""
if cost_matrix.size == 0: if cost_matrix.size == 0:
return cost_matrix return cost_matrix
iou_sim = 1 - cost_matrix iou_sim = 1 - cost_matrix

@ -393,6 +393,7 @@ def entrypoint(debug=''):
# Special modes -------------------------------------------------------------------------------------------------------- # Special modes --------------------------------------------------------------------------------------------------------
def copy_default_cfg(): def copy_default_cfg():
"""Copy and create a new default configuration file with '_copy' appended to its name."""
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml') new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace('.yaml', '_copy.yaml')
shutil.copy2(DEFAULT_CFG_PATH, new_file) shutil.copy2(DEFAULT_CFG_PATH, new_file)
LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n' LOGGER.info(f'{DEFAULT_CFG_PATH} copied to {new_file}\n'

@ -26,15 +26,19 @@ class BaseTransform:
pass pass
def apply_image(self, labels): def apply_image(self, labels):
"""Applies image transformation to labels."""
pass pass
def apply_instances(self, labels): def apply_instances(self, labels):
"""Applies transformations to input 'labels' and returns object instances."""
pass pass
def apply_semantic(self, labels): def apply_semantic(self, labels):
"""Applies semantic segmentation to an image."""
pass pass
def __call__(self, labels): def __call__(self, labels):
"""Applies label transformations to an image, instances and semantic masks."""
self.apply_image(labels) self.apply_image(labels)
self.apply_instances(labels) self.apply_instances(labels)
self.apply_semantic(labels) self.apply_semantic(labels)
@ -43,20 +47,25 @@ class BaseTransform:
class Compose: class Compose:
def __init__(self, transforms): def __init__(self, transforms):
"""Initializes the Compose object with a list of transforms."""
self.transforms = transforms self.transforms = transforms
def __call__(self, data): def __call__(self, data):
"""Applies a series of transformations to input data."""
for t in self.transforms: for t in self.transforms:
data = t(data) data = t(data)
return data return data
def append(self, transform): def append(self, transform):
"""Appends a new transform to the existing list of transforms."""
self.transforms.append(transform) self.transforms.append(transform)
def tolist(self): def tolist(self):
"""Converts list of transforms to a standard Python list."""
return self.transforms return self.transforms
def __repr__(self): def __repr__(self):
"""Return string representation of object."""
format_string = f'{self.__class__.__name__}(' format_string = f'{self.__class__.__name__}('
for t in self.transforms: for t in self.transforms:
format_string += '\n' format_string += '\n'
@ -74,6 +83,7 @@ class BaseMixTransform:
self.p = p self.p = p
def __call__(self, labels): def __call__(self, labels):
"""Applies pre-processing transforms and mixup/mosaic transforms to labels data."""
if random.uniform(0, 1) > self.p: if random.uniform(0, 1) > self.p:
return labels return labels
@ -96,9 +106,11 @@ class BaseMixTransform:
return labels return labels
def _mix_transform(self, labels): def _mix_transform(self, labels):
"""Applies MixUp or Mosaic augmentation to the label dictionary."""
raise NotImplementedError raise NotImplementedError
def get_indexes(self): def get_indexes(self):
"""Gets a list of shuffled indexes for mosaic augmentation."""
raise NotImplementedError raise NotImplementedError
@ -111,6 +123,7 @@ class Mosaic(BaseMixTransform):
""" """
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)): def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.' assert 0 <= p <= 1.0, 'The probability should be in range [0, 1]. ' f'got {p}.'
super().__init__(dataset=dataset, p=p) super().__init__(dataset=dataset, p=p)
self.dataset = dataset self.dataset = dataset
@ -118,9 +131,11 @@ class Mosaic(BaseMixTransform):
self.border = border self.border = border
def get_indexes(self): def get_indexes(self):
"""Return a list of 3 random indexes from the dataset."""
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)] return [random.randint(0, len(self.dataset) - 1) for _ in range(3)]
def _mix_transform(self, labels): def _mix_transform(self, labels):
"""Apply mixup transformation to the input image and labels."""
mosaic_labels = [] mosaic_labels = []
assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.' assert labels.get('rect_shape', None) is None, 'rect and mosaic is exclusive.'
assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.' assert len(labels.get('mix_labels', [])) > 0, 'There are no other images for mosaic augment.'
@ -166,6 +181,7 @@ class Mosaic(BaseMixTransform):
return labels return labels
def _cat_labels(self, mosaic_labels): def _cat_labels(self, mosaic_labels):
"""Return labels with mosaic border instances clipped."""
if len(mosaic_labels) == 0: if len(mosaic_labels) == 0:
return {} return {}
cls = [] cls = []
@ -190,6 +206,7 @@ class MixUp(BaseMixTransform):
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
def get_indexes(self): def get_indexes(self):
"""Get a random index from the dataset."""
return random.randint(0, len(self.dataset) - 1) return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels): def _mix_transform(self, labels):
@ -400,6 +417,7 @@ class RandomHSV:
self.vgain = vgain self.vgain = vgain
def __call__(self, labels): def __call__(self, labels):
"""Applies random horizontal or vertical flip to an image with a given probability."""
img = labels['img'] img = labels['img']
if self.hgain or self.sgain or self.vgain: if self.hgain or self.sgain or self.vgain:
r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains r = np.random.uniform(-1, 1, 3) * [self.hgain, self.sgain, self.vgain] + 1 # random gains
@ -427,6 +445,7 @@ class RandomFlip:
self.flip_idx = flip_idx self.flip_idx = flip_idx
def __call__(self, labels): def __call__(self, labels):
"""Resize image and padding for detection, instance segmentation, pose."""
img = labels['img'] img = labels['img']
instances = labels.pop('instances') instances = labels.pop('instances')
instances.convert_bbox(format='xywh') instances.convert_bbox(format='xywh')
@ -453,6 +472,7 @@ class LetterBox:
"""Resize image and padding for detection, instance segmentation, pose.""" """Resize image and padding for detection, instance segmentation, pose."""
def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32): def __init__(self, new_shape=(640, 640), auto=False, scaleFill=False, scaleup=True, stride=32):
"""Initialize LetterBox object with specific parameters."""
self.new_shape = new_shape self.new_shape = new_shape
self.auto = auto self.auto = auto
self.scaleFill = scaleFill self.scaleFill = scaleFill
@ -460,6 +480,7 @@ class LetterBox:
self.stride = stride self.stride = stride
def __call__(self, labels=None, image=None): def __call__(self, labels=None, image=None):
"""Return updated labels and image with added border."""
if labels is None: if labels is None:
labels = {} labels = {}
img = labels.get('img') if image is None else image img = labels.get('img') if image is None else image
@ -556,6 +577,7 @@ class CopyPaste:
class Albumentations: class Albumentations:
# YOLOv8 Albumentations class (optional, only used if package is installed) # YOLOv8 Albumentations class (optional, only used if package is installed)
def __init__(self, p=1.0): def __init__(self, p=1.0):
"""Initialize the transform object for YOLO bbox formatted params."""
self.p = p self.p = p
self.transform = None self.transform = None
prefix = colorstr('albumentations: ') prefix = colorstr('albumentations: ')
@ -581,6 +603,7 @@ class Albumentations:
LOGGER.info(f'{prefix}{e}') LOGGER.info(f'{prefix}{e}')
def __call__(self, labels): def __call__(self, labels):
"""Generates object detections and returns a dictionary with detection results."""
im = labels['img'] im = labels['img']
cls = labels['cls'] cls = labels['cls']
if len(cls): if len(cls):
@ -618,6 +641,7 @@ class Format:
self.batch_idx = batch_idx # keep the batch indexes self.batch_idx = batch_idx # keep the batch indexes
def __call__(self, labels): def __call__(self, labels):
"""Return formatted image, classes, bounding boxes & keypoints to be used by 'collate_fn'."""
img = labels.pop('img') img = labels.pop('img')
h, w = img.shape[:2] h, w = img.shape[:2]
cls = labels.pop('cls') cls = labels.pop('cls')
@ -647,6 +671,7 @@ class Format:
return labels return labels
def _format_img(self, img): def _format_img(self, img):
"""Format the image for YOLOv5 from Numpy array to PyTorch tensor."""
if len(img.shape) < 3: if len(img.shape) < 3:
img = np.expand_dims(img, -1) img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1]) img = np.ascontiguousarray(img.transpose(2, 0, 1)[::-1])
@ -668,6 +693,7 @@ class Format:
def v8_transforms(dataset, imgsz, hyp): def v8_transforms(dataset, imgsz, hyp):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose([ pre_transform = Compose([
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]), Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]),
CopyPaste(p=hyp.copy_paste), CopyPaste(p=hyp.copy_paste),
@ -749,6 +775,7 @@ def classify_albumentations(
class ClassifyLetterBox: class ClassifyLetterBox:
# YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) # YOLOv8 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, size=(640, 640), auto=False, stride=32): def __init__(self, size=(640, 640), auto=False, stride=32):
"""Resizes image and crops it to center with max dimensions 'h' and 'w'."""
super().__init__() super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride self.auto = auto # pass max size integer, automatically solve for short side using stride
@ -768,6 +795,7 @@ class ClassifyLetterBox:
class CenterCrop: class CenterCrop:
# YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()]) # YOLOv8 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
def __init__(self, size=640): def __init__(self, size=640):
"""Converts an image from numpy array to PyTorch tensor."""
super().__init__() super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size self.h, self.w = (size, size) if isinstance(size, int) else size
@ -781,6 +809,7 @@ class CenterCrop:
class ToTensor: class ToTensor:
# YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) # YOLOv8 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, half=False): def __init__(self, half=False):
"""Initialize YOLOv8 ToTensor object with optional half-precision support."""
super().__init__() super().__init__()
self.half = half self.half = half

@ -170,6 +170,7 @@ class BaseDataset(Dataset):
np.save(f.as_posix(), cv2.imread(self.im_files[i])) np.save(f.as_posix(), cv2.imread(self.im_files[i]))
def set_rectangle(self): def set_rectangle(self):
"""Sets the shape of bounding boxes for YOLO detections as rectangles."""
bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int) # batch index
nb = bi[-1] + 1 # number of batches nb = bi[-1] + 1 # number of batches
@ -194,9 +195,11 @@ class BaseDataset(Dataset):
self.batch = bi # batch index of image self.batch = bi # batch index of image
def __getitem__(self, index): def __getitem__(self, index):
"""Returns transformed label information for given index."""
return self.transforms(self.get_label_info(index)) return self.transforms(self.get_label_info(index))
def get_label_info(self, index): def get_label_info(self, index):
"""Get and return label information from the dataset."""
label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948 label = deepcopy(self.labels[index]) # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
label.pop('shape', None) # shape is for rect, remove it label.pop('shape', None) # shape is for rect, remove it
label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index) label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
@ -208,6 +211,7 @@ class BaseDataset(Dataset):
return label return label
def __len__(self): def __len__(self):
"""Returns the length of the labels list for the dataset."""
return len(self.labels) return len(self.labels)
def update_labels_info(self, label): def update_labels_info(self, label):

@ -24,14 +24,17 @@ class InfiniteDataLoader(dataloader.DataLoader):
"""Dataloader that reuses workers. Uses same syntax as vanilla DataLoader.""" """Dataloader that reuses workers. Uses same syntax as vanilla DataLoader."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Dataloader that infinitely recycles workers, inherits from DataLoader."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
"""Returns the length of the batch sampler's sampler."""
return len(self.batch_sampler.sampler) return len(self.batch_sampler.sampler)
def __iter__(self): def __iter__(self):
"""Creates a sampler that repeats indefinitely."""
for _ in range(len(self)): for _ in range(len(self)):
yield next(self.iterator) yield next(self.iterator)
@ -45,9 +48,11 @@ class _RepeatSampler:
""" """
def __init__(self, sampler): def __init__(self, sampler):
"""Initializes an object that repeats a given sampler indefinitely."""
self.sampler = sampler self.sampler = sampler
def __iter__(self): def __iter__(self):
"""Iterates over the 'sampler' and yields its contents."""
while True: while True:
yield from iter(self.sampler) yield from iter(self.sampler)
@ -60,6 +65,7 @@ def seed_worker(worker_id): # noqa
def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, rank=-1, mode='train'): def build_dataloader(cfg, batch, img_path, data_info, stride=32, rect=False, rank=-1, mode='train'):
"""Return an InfiniteDataLoader or DataLoader for training or validation set."""
assert mode in ['train', 'val'] assert mode in ['train', 'val']
shuffle = mode == 'train' shuffle = mode == 'train'
if cfg.rect and shuffle: if cfg.rect and shuffle:
@ -134,6 +140,7 @@ def build_classification_dataloader(path,
def check_source(source): def check_source(source):
"""Check source type and return corresponding flag values."""
webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
if isinstance(source, (str, int, Path)): # int for local usb camera if isinstance(source, (str, int, Path)): # int for local usb camera
source = str(source) source = str(source)

@ -32,6 +32,7 @@ class SourceTypes:
class LoadStreams: class LoadStreams:
# YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` # YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): def __init__(self, sources='file.streams', imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream' self.mode = 'stream'
self.imgsz = imgsz self.imgsz = imgsz
@ -97,10 +98,12 @@ class LoadStreams:
time.sleep(0.0) # wait time time.sleep(0.0) # wait time
def __iter__(self): def __iter__(self):
"""Iterates through YOLO image feed and re-opens unresponsive streams."""
self.count = -1 self.count = -1
return self return self
def __next__(self): def __next__(self):
"""Returns source paths, transformed and original images for processing YOLOv5."""
self.count += 1 self.count += 1
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
cv2.destroyAllWindows() cv2.destroyAllWindows()
@ -117,6 +120,7 @@ class LoadStreams:
return self.sources, im, im0, None, '' return self.sources, im, im0, None, ''
def __len__(self): def __len__(self):
"""Return the length of the sources object."""
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
@ -153,6 +157,7 @@ class LoadScreenshots:
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height} self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
def __iter__(self): def __iter__(self):
"""Returns an iterator of the object."""
return self return self
def __next__(self): def __next__(self):
@ -173,6 +178,7 @@ class LoadScreenshots:
class LoadImages: class LoadImages:
# YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4` # YOLOv8 image/video dataloader, i.e. `yolo predict source=image.jpg/vid.mp4`
def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1): def __init__(self, path, imgsz=640, stride=32, auto=True, transforms=None, vid_stride=1):
"""Initialize the Dataloader and raise FileNotFoundError if file not found."""
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit() path = Path(path).read_text().rsplit()
files = [] files = []
@ -211,10 +217,12 @@ class LoadImages:
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}') f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}')
def __iter__(self): def __iter__(self):
"""Returns an iterator object for VideoStream or ImageFolder."""
self.count = 0 self.count = 0
return self return self
def __next__(self): def __next__(self):
"""Return next image, path and metadata from dataset."""
if self.count == self.nf: if self.count == self.nf:
raise StopIteration raise StopIteration
path = self.files[self.count] path = self.files[self.count]
@ -276,12 +284,14 @@ class LoadImages:
return im return im
def __len__(self): def __len__(self):
"""Returns the number of files in the object."""
return self.nf # number of files return self.nf # number of files
class LoadPilAndNumpy: class LoadPilAndNumpy:
def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None): def __init__(self, im0, imgsz=640, stride=32, auto=True, transforms=None):
"""Initialize PIL and Numpy Dataloader."""
if not isinstance(im0, list): if not isinstance(im0, list):
im0 = [im0] im0 = [im0]
self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)] self.paths = [getattr(im, 'filename', f'image{i}.jpg') for i, im in enumerate(im0)]
@ -296,6 +306,7 @@ class LoadPilAndNumpy:
@staticmethod @staticmethod
def _single_check(im): def _single_check(im):
"""Validate and format an image to numpy array."""
assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}' assert isinstance(im, (Image.Image, np.ndarray)), f'Expected PIL/np.ndarray image type, but got {type(im)}'
if isinstance(im, Image.Image): if isinstance(im, Image.Image):
if im.mode != 'RGB': if im.mode != 'RGB':
@ -305,6 +316,7 @@ class LoadPilAndNumpy:
return im return im
def _single_preprocess(self, im, auto): def _single_preprocess(self, im, auto):
"""Preprocesses a single image for inference."""
if self.transforms: if self.transforms:
im = self.transforms(im) # transforms im = self.transforms(im) # transforms
else: else:
@ -314,9 +326,11 @@ class LoadPilAndNumpy:
return im return im
def __len__(self): def __len__(self):
"""Returns the length of the 'im0' attribute."""
return len(self.im0) return len(self.im0)
def __next__(self): def __next__(self):
"""Returns batch paths, images, processed images, None, ''."""
if self.count == 1: # loop only once as it's batch inference if self.count == 1: # loop only once as it's batch inference
raise StopIteration raise StopIteration
auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto auto = all(x.shape == self.im0[0].shape for x in self.im0) and self.auto
@ -326,6 +340,7 @@ class LoadPilAndNumpy:
return self.paths, im, self.im0, None, '' return self.paths, im, self.im0, None, ''
def __iter__(self): def __iter__(self):
"""Enables iteration for class LoadPilAndNumpy."""
self.count = 0 self.count = 0
return self return self
@ -338,16 +353,19 @@ class LoadTensor:
self.mode = 'image' self.mode = 'image'
def __iter__(self): def __iter__(self):
"""Returns an iterator object."""
self.count = 0 self.count = 0
return self return self
def __next__(self): def __next__(self):
"""Return next item in the iterator."""
if self.count == 1: if self.count == 1:
raise StopIteration raise StopIteration
self.count += 1 self.count += 1
return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, '' return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, ''
def __len__(self): def __len__(self):
"""Returns the batch size."""
return self.bs return self.bs

@ -24,6 +24,7 @@ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
class Albumentations: class Albumentations:
# YOLOv5 Albumentations class (optional, only used if package is installed) # YOLOv5 Albumentations class (optional, only used if package is installed)
def __init__(self, size=640): def __init__(self, size=640):
"""Instantiate object with image augmentations for YOLOv5."""
self.transform = None self.transform = None
prefix = colorstr('albumentations: ') prefix = colorstr('albumentations: ')
try: try:
@ -48,6 +49,7 @@ class Albumentations:
LOGGER.info(f'{prefix}{e}') LOGGER.info(f'{prefix}{e}')
def __call__(self, im, labels, p=1.0): def __call__(self, im, labels, p=1.0):
"""Transforms input image and labels with probability 'p'."""
if self.transform and random.random() < p: if self.transform and random.random() < p:
new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]) im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
@ -111,7 +113,7 @@ def replicate(im, labels):
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints """Resize and pad image while meeting stride-multiple constraints."""
shape = im.shape[:2] # current shape [height, width] shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int): if isinstance(new_shape, int):
new_shape = (new_shape, new_shape) new_shape = (new_shape, new_shape)
@ -359,6 +361,7 @@ def classify_transforms(size=224):
class LetterBox: class LetterBox:
# YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, size=(640, 640), auto=False, stride=32): def __init__(self, size=(640, 640), auto=False, stride=32):
"""Resizes and crops an image to a specified size for YOLOv5 preprocessing."""
super().__init__() super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size self.h, self.w = (size, size) if isinstance(size, int) else size
self.auto = auto # pass max size integer, automatically solve for short side using stride self.auto = auto # pass max size integer, automatically solve for short side using stride
@ -378,6 +381,7 @@ class LetterBox:
class CenterCrop: class CenterCrop:
# YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()]) # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
def __init__(self, size=640): def __init__(self, size=640):
"""Converts input image into tensor for YOLOv5 processing."""
super().__init__() super().__init__()
self.h, self.w = (size, size) if isinstance(size, int) else size self.h, self.w = (size, size) if isinstance(size, int) else size
@ -391,6 +395,7 @@ class CenterCrop:
class ToTensor: class ToTensor:
# YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()]) # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
def __init__(self, half=False): def __init__(self, half=False):
"""Initialize ToTensor class for YOLOv5 image preprocessing."""
super().__init__() super().__init__()
self.half = half self.half = half

@ -162,14 +162,17 @@ class InfiniteDataLoader(dataloader.DataLoader):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Dataloader that reuses workers for same syntax as vanilla DataLoader."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__() self.iterator = super().__iter__()
def __len__(self): def __len__(self):
"""Returns the length of batch_sampler's sampler."""
return len(self.batch_sampler.sampler) return len(self.batch_sampler.sampler)
def __iter__(self): def __iter__(self):
"""Creates a sampler that infinitely repeats."""
for _ in range(len(self)): for _ in range(len(self)):
yield next(self.iterator) yield next(self.iterator)
@ -182,9 +185,11 @@ class _RepeatSampler:
""" """
def __init__(self, sampler): def __init__(self, sampler):
"""Sampler that repeats dataset samples infinitely."""
self.sampler = sampler self.sampler = sampler
def __iter__(self): def __iter__(self):
"""Infinite loop iterating over a given sampler."""
while True: while True:
yield from iter(self.sampler) yield from iter(self.sampler)
@ -221,6 +226,7 @@ class LoadScreenshots:
self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height} self.monitor = {'left': self.left, 'top': self.top, 'width': self.width, 'height': self.height}
def __iter__(self): def __iter__(self):
"""Iterates over objects with the same structure as the monitor attribute."""
return self return self
def __next__(self): def __next__(self):
@ -241,6 +247,7 @@ class LoadScreenshots:
class LoadImages: class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4` # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
"""Initialize instance variables and check for valid input."""
if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line if isinstance(path, str) and Path(path).suffix == '.txt': # *.txt file with img/vid/dir on each line
path = Path(path).read_text().rsplit() path = Path(path).read_text().rsplit()
files = [] files = []
@ -276,10 +283,12 @@ class LoadImages:
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}' f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
def __iter__(self): def __iter__(self):
"""Returns an iterator object for iterating over images or videos found in a directory."""
self.count = 0 self.count = 0
return self return self
def __next__(self): def __next__(self):
"""Iterator's next item, performs transformation on image and returns path, transformed image, original image, capture and size."""
if self.count == self.nf: if self.count == self.nf:
raise StopIteration raise StopIteration
path = self.files[self.count] path = self.files[self.count]
@ -338,12 +347,14 @@ class LoadImages:
return im return im
def __len__(self): def __len__(self):
"""Returns the number of files in the class instance."""
return self.nf # number of files return self.nf # number of files
class LoadStreams: class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams` # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1): def __init__(self, sources='file.streams', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
"""Initialize YOLO detector with optional transforms and check input shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.mode = 'stream' self.mode = 'stream'
self.img_size = img_size self.img_size = img_size
@ -404,10 +415,12 @@ class LoadStreams:
time.sleep(0.0) # wait time time.sleep(0.0) # wait time
def __iter__(self): def __iter__(self):
"""Iterator that returns the class instance."""
self.count = -1 self.count = -1
return self return self
def __next__(self): def __next__(self):
"""Return a tuple containing transformed and resized image data."""
self.count += 1 self.count += 1
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
cv2.destroyAllWindows() cv2.destroyAllWindows()
@ -424,6 +437,7 @@ class LoadStreams:
return self.sources, im, im0, None, '' return self.sources, im, im0, None, ''
def __len__(self): def __len__(self):
"""Returns the number of sources as the length of the object."""
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
@ -607,6 +621,7 @@ class LoadImagesAndLabels(Dataset):
return cache return cache
def cache_labels(self, path=Path('./labels.cache'), prefix=''): def cache_labels(self, path=Path('./labels.cache'), prefix=''):
"""Cache labels and save as numpy file for next time."""
# Cache dataset labels, check images and read shapes # Cache dataset labels, check images and read shapes
if path.exists(): if path.exists():
path.unlink() # remove *.cache file if exists path.unlink() # remove *.cache file if exists
@ -646,9 +661,11 @@ class LoadImagesAndLabels(Dataset):
return x return x
def __len__(self): def __len__(self):
"""Returns the length of 'im_files' attribute."""
return len(self.im_files) return len(self.im_files)
def __getitem__(self, index): def __getitem__(self, index):
"""Get a sample and its corresponding label, filename and shape from the dataset."""
index = self.indices[index] # linear, shuffled, or image_weights index = self.indices[index] # linear, shuffled, or image_weights
hyp = self.hyp hyp = self.hyp
@ -1039,6 +1056,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
""" """
def __init__(self, root, augment, imgsz, cache=False): def __init__(self, root, augment, imgsz, cache=False):
"""Initialize YOLO dataset with root, augmentation, image size, and cache parameters."""
super().__init__(root=root) super().__init__(root=root)
self.torch_transforms = classify_transforms(imgsz) self.torch_transforms = classify_transforms(imgsz)
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
@ -1047,6 +1065,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
def __getitem__(self, i): def __getitem__(self, i):
"""Retrieves data items of 'dataset' via indices & creates InfiniteDataLoader."""
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram and im is None: if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f) im = self.samples[i][3] = cv2.imread(f)

@ -127,6 +127,7 @@ class YOLODataset(BaseDataset):
return x return x
def get_labels(self): def get_labels(self):
"""Returns dictionary of labels for YOLO training."""
self.label_files = img2label_paths(self.im_files) self.label_files = img2label_paths(self.im_files)
cache_path = Path(self.label_files[0]).parent.with_suffix('.cache') cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
try: try:
@ -170,6 +171,7 @@ class YOLODataset(BaseDataset):
# TODO: use hyp config to set all these augmentations # TODO: use hyp config to set all these augmentations
def build_transforms(self, hyp=None): def build_transforms(self, hyp=None):
"""Builds and appends transforms to the list."""
if self.augment: if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@ -187,6 +189,7 @@ class YOLODataset(BaseDataset):
return transforms return transforms
def close_mosaic(self, hyp): def close_mosaic(self, hyp):
"""Sets mosaic, copy_paste and mixup options to 0.0 and builds transformations."""
hyp.mosaic = 0.0 # set mosaic ratio=0.0 hyp.mosaic = 0.0 # set mosaic ratio=0.0
hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic hyp.copy_paste = 0.0 # keep the same behavior as previous v8 close-mosaic
hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic hyp.mixup = 0.0 # keep the same behavior as previous v8 close-mosaic
@ -206,6 +209,7 @@ class YOLODataset(BaseDataset):
@staticmethod @staticmethod
def collate_fn(batch): def collate_fn(batch):
"""Collates data samples into batches."""
new_batch = {} new_batch = {}
keys = batch[0].keys() keys = batch[0].keys()
values = list(zip(*[list(b.values()) for b in batch])) values = list(zip(*[list(b.values()) for b in batch]))
@ -234,6 +238,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
""" """
def __init__(self, root, augment, imgsz, cache=False): def __init__(self, root, augment, imgsz, cache=False):
"""Initialize YOLO object with root, image size, augmentations, and cache settings"""
super().__init__(root=root) super().__init__(root=root)
self.torch_transforms = classify_transforms(imgsz) self.torch_transforms = classify_transforms(imgsz)
self.album_transforms = classify_albumentations(augment, imgsz) if augment else None self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
@ -242,6 +247,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
def __getitem__(self, i): def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram and im is None: if self.cache_ram and im is None:
im = self.samples[i][3] = cv2.imread(f) im = self.samples[i][3] = cv2.imread(f)
@ -265,4 +271,5 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
class SemanticDataset(BaseDataset): class SemanticDataset(BaseDataset):
def __init__(self): def __init__(self):
"""Initialize a SemanticDataset object."""
pass pass

@ -359,6 +359,7 @@ class HUBDatasetStats():
return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path return True, str(dir), self._find_yaml(dir) # zipped, data_dir, yaml_path
def _hub_ops(self, f): def _hub_ops(self, f):
"""Saves a compressed image for HUB previews."""
compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub compress_one_image(f, self.im_dir / Path(f).name) # save to dataset-hub
def get_json(self, save=False, verbose=False): def get_json(self, save=False, verbose=False):

@ -105,6 +105,7 @@ def try_export(inner_func):
inner_args = get_default_args(inner_func) inner_args = get_default_args(inner_func)
def outer_func(*args, **kwargs): def outer_func(*args, **kwargs):
"""Export a model."""
prefix = inner_args['prefix'] prefix = inner_args['prefix']
try: try:
with Profile() as dt: with Profile() as dt:
@ -118,24 +119,6 @@ def try_export(inner_func):
return outer_func return outer_func
class iOSDetectModel(torch.nn.Module):
"""Wrap an Ultralytics YOLO model for iOS export."""
def __init__(self, model, im):
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
class Exporter: class Exporter:
""" """
A class for exporting a model. A class for exporting a model.
@ -160,6 +143,7 @@ class Exporter:
@smart_inference_mode() @smart_inference_mode()
def __call__(self, model=None): def __call__(self, model=None):
"""Returns list of exported files/dirs after running callbacks."""
self.run_callbacks('on_export_start') self.run_callbacks('on_export_start')
t = time.time() t = time.time()
format = self.args.format.lower() # to lowercase format = self.args.format.lower() # to lowercase
@ -703,7 +687,7 @@ class Exporter:
tmp_file.unlink() tmp_file.unlink()
def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')): def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
# YOLOv8 CoreML pipeline """YOLOv8 CoreML pipeline."""
import coremltools as ct # noqa import coremltools as ct # noqa
LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...') LOGGER.info(f'{prefix} starting pipeline with coremltools {ct.__version__}...')
@ -826,11 +810,33 @@ class Exporter:
self.callbacks[event].append(callback) self.callbacks[event].append(callback)
def run_callbacks(self, event: str): def run_callbacks(self, event: str):
"""Execute all callbacks for a given event."""
for callback in self.callbacks.get(event, []): for callback in self.callbacks.get(event, []):
callback(self) callback(self)
class iOSDetectModel(torch.nn.Module):
"""Wrap an Ultralytics YOLO model for iOS export."""
def __init__(self, model, im):
"""Initialize the iOSDetectModel class with a YOLO model and example image."""
super().__init__()
b, c, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
"""Normalize predictions of object detection model with input size-dependent factors."""
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
def export(cfg=DEFAULT_CFG): def export(cfg=DEFAULT_CFG):
"""Export a YOLOv model to a specific format."""
cfg.model = cfg.model or 'yolov8n.yaml' cfg.model = cfg.model or 'yolov8n.yaml'
cfg.format = cfg.format or 'torchscript' cfg.format = cfg.format or 'torchscript'

@ -107,14 +107,17 @@ class YOLO:
self._load(model, task) self._load(model, task)
def __call__(self, source=None, stream=False, **kwargs): def __call__(self, source=None, stream=False, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs) return self.predict(source, stream, **kwargs)
def __getattr__(self, attr): def __getattr__(self, attr):
"""Raises error if object has no requested attribute."""
name = self.__class__.__name__ name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@staticmethod @staticmethod
def is_hub_model(model): def is_hub_model(model):
"""Check if the provided model is a HUB model."""
return any(( return any((
model.startswith('https://hub.ultra'), # i.e. https://hub.ultralytics.com/models/MODEL_ID model.startswith('https://hub.ultra'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
@ -209,6 +212,7 @@ class YOLO:
self.model.info(verbose=verbose) self.model.info(verbose=verbose)
def fuse(self): def fuse(self):
"""Fuse PyTorch Conv2d and BatchNorm2d layers."""
self._check_is_pytorch_model() self._check_is_pytorch_model()
self.model.fuse() self.model.fuse()
@ -493,9 +497,11 @@ class YOLO:
@staticmethod @staticmethod
def _reset_ckpt_args(args): def _reset_ckpt_args(args):
"""Reset arguments when loading a PyTorch model."""
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
return {k: v for k, v in args.items() if k in include} return {k: v for k, v in args.items() if k in include}
def _reset_callbacks(self): def _reset_callbacks(self):
"""Reset all registered callbacks."""
for event in callbacks.default_callbacks.keys(): for event in callbacks.default_callbacks.keys():
self.callbacks[event] = [callbacks.default_callbacks[event][0]] self.callbacks[event] = [callbacks.default_callbacks[event][0]]

@ -107,9 +107,11 @@ class BasePredictor:
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
def preprocess(self, img): def preprocess(self, img):
"""Prepares input image before inference."""
pass pass
def write_results(self, idx, results, batch): def write_results(self, idx, results, batch):
"""Write inference results to a file or directory."""
p, im, _ = batch p, im, _ = batch
log_string = '' log_string = ''
if len(im.shape) == 3: if len(im.shape) == 3:
@ -143,9 +145,11 @@ class BasePredictor:
return log_string return log_string
def postprocess(self, preds, img, orig_img): def postprocess(self, preds, img, orig_img):
"""Post-processes predictions for an image and returns them."""
return preds return preds
def __call__(self, source=None, model=None, stream=False): def __call__(self, source=None, model=None, stream=False):
"""Performs inference on an image or stream."""
self.stream = stream self.stream = stream
if stream: if stream:
return self.stream_inference(source, model) return self.stream_inference(source, model)
@ -159,6 +163,7 @@ class BasePredictor:
pass pass
def setup_source(self, source): def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
if self.args.task == 'classify': if self.args.task == 'classify':
transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0])) transforms = getattr(self.model.model, 'transforms', classify_transforms(self.imgsz[0]))
@ -179,6 +184,7 @@ class BasePredictor:
@smart_inference_mode() @smart_inference_mode()
def stream_inference(self, source=None, model=None): def stream_inference(self, source=None, model=None):
"""Streams real-time inference on camera feed and saves results to file."""
if self.args.verbose: if self.args.verbose:
LOGGER.info('') LOGGER.info('')
@ -264,6 +270,7 @@ class BasePredictor:
self.run_callbacks('on_predict_end') self.run_callbacks('on_predict_end')
def setup_model(self, model, verbose=True): def setup_model(self, model, verbose=True):
"""Initialize YOLO model with given parameters and set it to evaluation mode."""
device = select_device(self.args.device, verbose=verbose) device = select_device(self.args.device, verbose=verbose)
model = model or self.args.model model = model or self.args.model
self.args.half &= device.type != 'cpu' # half precision only supported on CUDA self.args.half &= device.type != 'cpu' # half precision only supported on CUDA
@ -278,6 +285,7 @@ class BasePredictor:
self.model.eval() self.model.eval()
def show(self, p): def show(self, p):
"""Display an image in a window using OpenCV imshow()."""
im0 = self.plotted_img im0 = self.plotted_img
if platform.system() == 'Linux' and p not in self.windows: if platform.system() == 'Linux' and p not in self.windows:
self.windows.append(p) self.windows.append(p)
@ -287,6 +295,7 @@ class BasePredictor:
cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond cv2.waitKey(500 if self.batch[4].startswith('image') else 1) # 1 millisecond
def save_preds(self, vid_cap, idx, save_path): def save_preds(self, vid_cap, idx, save_path):
"""Save video predictions as mp4 at specified path."""
im0 = self.plotted_img im0 = self.plotted_img
# Save imgs # Save imgs
if self.dataset.mode == 'image': if self.dataset.mode == 'image':
@ -307,6 +316,7 @@ class BasePredictor:
self.vid_writer[idx].write(im0) self.vid_writer[idx].write(im0)
def run_callbacks(self, event: str): def run_callbacks(self, event: str):
"""Runs all registered callbacks for a specific event."""
for callback in self.callbacks.get(event, []): for callback in self.callbacks.get(event, []):
callback(self) callback(self)

@ -19,42 +19,41 @@ from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
class BaseTensor(SimpleClass): class BaseTensor(SimpleClass):
""" """
Base tensor class with additional methods for easy manipulation and device handling.
Attributes:
data (torch.Tensor): Base tensor.
orig_shape (tuple): Original image size, in the format (height, width).
Methods:
cpu(): Returns a copy of the tensor on CPU memory.
numpy(): Returns a copy of the tensor as a numpy array.
cuda(): Returns a copy of the tensor on GPU memory.
to(): Returns a copy of the tensor with the specified device and dtype.
""" """
def __init__(self, data, orig_shape) -> None: def __init__(self, data, orig_shape) -> None:
"""Initialize BaseTensor with data and original shape."""
self.data = data self.data = data
self.orig_shape = orig_shape self.orig_shape = orig_shape
@property @property
def shape(self): def shape(self):
"""Return the shape of the data tensor."""
return self.data.shape return self.data.shape
def cpu(self): def cpu(self):
"""Return a copy of the tensor on CPU memory."""
return self.__class__(self.data.cpu(), self.orig_shape) return self.__class__(self.data.cpu(), self.orig_shape)
def numpy(self): def numpy(self):
"""Return a copy of the tensor as a numpy array."""
return self.__class__(self.data.numpy(), self.orig_shape) return self.__class__(self.data.numpy(), self.orig_shape)
def cuda(self): def cuda(self):
"""Return a copy of the tensor on GPU memory."""
return self.__class__(self.data.cuda(), self.orig_shape) return self.__class__(self.data.cuda(), self.orig_shape)
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
"""Return a copy of the tensor with the specified device and dtype."""
return self.__class__(self.data.to(*args, **kwargs), self.orig_shape) return self.__class__(self.data.to(*args, **kwargs), self.orig_shape)
def __len__(self): # override len(results) def __len__(self): # override len(results)
"""Return the length of the data tensor."""
return len(self.data) return len(self.data)
def __getitem__(self, idx): def __getitem__(self, idx):
"""Return a BaseTensor with the specified index of the data tensor."""
return self.__class__(self.data[idx], self.orig_shape) return self.__class__(self.data[idx], self.orig_shape)
@ -83,10 +82,10 @@ class Results(SimpleClass):
keypoints (List[List[float]], optional): A list of detected keypoints for each object. keypoints (List[List[float]], optional): A list of detected keypoints for each object.
speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image. speed (dict): A dictionary of preprocess, inference and postprocess speeds in milliseconds per image.
_keys (tuple): A tuple of attribute names for non-empty attributes. _keys (tuple): A tuple of attribute names for non-empty attributes.
""" """
def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None: def __init__(self, orig_img, path, names, boxes=None, masks=None, probs=None, keypoints=None) -> None:
"""Initialize the Results class."""
self.orig_img = orig_img self.orig_img = orig_img
self.orig_shape = orig_img.shape[:2] self.orig_shape = orig_img.shape[:2]
self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes self.boxes = Boxes(boxes, self.orig_shape) if boxes is not None else None # native size boxes
@ -99,16 +98,19 @@ class Results(SimpleClass):
self._keys = ('boxes', 'masks', 'probs', 'keypoints') self._keys = ('boxes', 'masks', 'probs', 'keypoints')
def pandas(self): def pandas(self):
"""Convert the results to a pandas DataFrame."""
pass pass
# TODO masks.pandas + boxes.pandas + cls.pandas # TODO masks.pandas + boxes.pandas + cls.pandas
def __getitem__(self, idx): def __getitem__(self, idx):
"""Return a Results object for the specified index."""
r = self.new() r = self.new()
for k in self.keys: for k in self.keys:
setattr(r, k, getattr(self, k)[idx]) setattr(r, k, getattr(self, k)[idx])
return r return r
def update(self, boxes=None, masks=None, probs=None): def update(self, boxes=None, masks=None, probs=None):
"""Update the boxes, masks, and probs attributes of the Results object."""
if boxes is not None: if boxes is not None:
self.boxes = Boxes(boxes, self.orig_shape) self.boxes = Boxes(boxes, self.orig_shape)
if masks is not None: if masks is not None:
@ -117,38 +119,45 @@ class Results(SimpleClass):
self.probs = probs self.probs = probs
def cpu(self): def cpu(self):
"""Return a copy of the Results object with all tensors on CPU memory."""
r = self.new() r = self.new()
for k in self.keys: for k in self.keys:
setattr(r, k, getattr(self, k).cpu()) setattr(r, k, getattr(self, k).cpu())
return r return r
def numpy(self): def numpy(self):
"""Return a copy of the Results object with all tensors as numpy arrays."""
r = self.new() r = self.new()
for k in self.keys: for k in self.keys:
setattr(r, k, getattr(self, k).numpy()) setattr(r, k, getattr(self, k).numpy())
return r return r
def cuda(self): def cuda(self):
"""Return a copy of the Results object with all tensors on GPU memory."""
r = self.new() r = self.new()
for k in self.keys: for k in self.keys:
setattr(r, k, getattr(self, k).cuda()) setattr(r, k, getattr(self, k).cuda())
return r return r
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
"""Return a copy of the Results object with tensors on the specified device and dtype."""
r = self.new() r = self.new()
for k in self.keys: for k in self.keys:
setattr(r, k, getattr(self, k).to(*args, **kwargs)) setattr(r, k, getattr(self, k).to(*args, **kwargs))
return r return r
def __len__(self): def __len__(self):
"""Return the number of detections in the Results object."""
for k in self.keys: for k in self.keys:
return len(getattr(self, k)) return len(getattr(self, k))
def new(self): def new(self):
"""Return a new Results object with the same image, path, and names."""
return Results(orig_img=self.orig_img, path=self.path, names=self.names) return Results(orig_img=self.orig_img, path=self.path, names=self.names)
@property @property
def keys(self): def keys(self):
"""Return a list of non-empty attribute names."""
return [k for k in self._keys if getattr(self, k) is not None] return [k for k in self._keys if getattr(self, k) is not None]
def plot( def plot(
@ -250,7 +259,8 @@ class Results(SimpleClass):
return log_string return log_string
def save_txt(self, txt_file, save_conf=False): def save_txt(self, txt_file, save_conf=False):
"""Save predictions into txt file. """
Save predictions into txt file.
Args: Args:
txt_file (str): txt file path. txt_file (str): txt file path.
@ -285,7 +295,8 @@ class Results(SimpleClass):
f.write(text + '\n') f.write(text + '\n')
def save_crop(self, save_dir, file_name=Path('im.jpg')): def save_crop(self, save_dir, file_name=Path('im.jpg')):
"""Save cropped predictions to `save_dir/cls/file_name.jpg`. """
Save cropped predictions to `save_dir/cls/file_name.jpg`.
Args: Args:
save_dir (str | pathlib.Path): Save path. save_dir (str | pathlib.Path): Save path.
@ -338,6 +349,7 @@ class Boxes(BaseTensor):
""" """
def __init__(self, boxes, orig_shape) -> None: def __init__(self, boxes, orig_shape) -> None:
"""Initialize the Boxes class."""
if boxes.ndim == 1: if boxes.ndim == 1:
boxes = boxes[None, :] boxes = boxes[None, :]
n = boxes.shape[-1] n = boxes.shape[-1]
@ -349,40 +361,49 @@ class Boxes(BaseTensor):
@property @property
def xyxy(self): def xyxy(self):
"""Return the boxes in xyxy format."""
return self.data[:, :4] return self.data[:, :4]
@property @property
def conf(self): def conf(self):
"""Return the confidence values of the boxes."""
return self.data[:, -2] return self.data[:, -2]
@property @property
def cls(self): def cls(self):
"""Return the class values of the boxes."""
return self.data[:, -1] return self.data[:, -1]
@property @property
def id(self): def id(self):
"""Return the track IDs of the boxes (if available)."""
return self.data[:, -3] if self.is_track else None return self.data[:, -3] if self.is_track else None
@property @property
@lru_cache(maxsize=2) # maxsize 1 should suffice @lru_cache(maxsize=2) # maxsize 1 should suffice
def xywh(self): def xywh(self):
"""Return the boxes in xywh format."""
return ops.xyxy2xywh(self.xyxy) return ops.xyxy2xywh(self.xyxy)
@property @property
@lru_cache(maxsize=2) @lru_cache(maxsize=2)
def xyxyn(self): def xyxyn(self):
"""Return the boxes in xyxy format normalized by original image size."""
return self.xyxy / self.orig_shape[[1, 0, 1, 0]] return self.xyxy / self.orig_shape[[1, 0, 1, 0]]
@property @property
@lru_cache(maxsize=2) @lru_cache(maxsize=2)
def xywhn(self): def xywhn(self):
"""Return the boxes in xywh format normalized by original image size."""
return self.xywh / self.orig_shape[[1, 0, 1, 0]] return self.xywh / self.orig_shape[[1, 0, 1, 0]]
def pandas(self): def pandas(self):
"""Convert the object to a pandas DataFrame (not yet implemented)."""
LOGGER.info('results.pandas() method not yet implemented') LOGGER.info('results.pandas() method not yet implemented')
@property @property
def boxes(self): def boxes(self):
"""Return the raw bboxes tensor (deprecated)."""
LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.") LOGGER.warning("WARNING ⚠️ 'Boxes.boxes' is deprecated. Use 'Boxes.data' instead.")
return self.data return self.data
@ -411,6 +432,7 @@ class Masks(BaseTensor):
""" """
def __init__(self, masks, orig_shape) -> None: def __init__(self, masks, orig_shape) -> None:
"""Initialize the Masks class."""
if masks.ndim == 2: if masks.ndim == 2:
masks = masks[None, :] masks = masks[None, :]
super().__init__(masks, orig_shape) super().__init__(masks, orig_shape)
@ -418,7 +440,7 @@ class Masks(BaseTensor):
@property @property
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def segments(self): def segments(self):
"""Segments-deprecated (normalized).""" """Return segments (deprecated; normalized)."""
LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and " LOGGER.warning("WARNING ⚠️ 'Masks.segments' is deprecated. Use 'Masks.xyn' for segments (normalized) and "
"'Masks.xy' for segments (pixels) instead.") "'Masks.xy' for segments (pixels) instead.")
return self.xyn return self.xyn
@ -426,7 +448,7 @@ class Masks(BaseTensor):
@property @property
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def xyn(self): def xyn(self):
"""Segments (normalized).""" """Return segments (normalized)."""
return [ return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True) ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=True)
for x in ops.masks2segments(self.data)] for x in ops.masks2segments(self.data)]
@ -434,12 +456,13 @@ class Masks(BaseTensor):
@property @property
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def xy(self): def xy(self):
"""Segments (pixels).""" """Return segments (pixels)."""
return [ return [
ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False) ops.scale_coords(self.data.shape[1:], x, self.orig_shape, normalize=False)
for x in ops.masks2segments(self.data)] for x in ops.masks2segments(self.data)]
@property @property
def masks(self): def masks(self):
"""Return the raw masks tensor (deprecated)."""
LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.") LOGGER.warning("WARNING ⚠️ 'Masks.masks' is deprecated. Use 'Masks.data' instead.")
return self.data return self.data

@ -159,6 +159,7 @@ class BaseTrainer:
self.callbacks[event] = [callback] self.callbacks[event] = [callback]
def run_callbacks(self, event: str): def run_callbacks(self, event: str):
"""Run all existing callbacks associated with a particular event."""
for callback in self.callbacks.get(event, []): for callback in self.callbacks.get(event, []):
callback(self) callback(self)
@ -190,6 +191,7 @@ class BaseTrainer:
self._do_train(world_size) self._do_train(world_size)
def _setup_ddp(self, world_size): def _setup_ddp(self, world_size):
"""Initializes and sets the DistributedDataParallel parameters for training."""
torch.cuda.set_device(RANK) torch.cuda.set_device(RANK)
self.device = torch.device('cuda', RANK) self.device = torch.device('cuda', RANK)
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}') LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
@ -259,6 +261,7 @@ class BaseTrainer:
self.run_callbacks('on_pretrain_routine_end') self.run_callbacks('on_pretrain_routine_end')
def _do_train(self, world_size=1): def _do_train(self, world_size=1):
"""Train completed, evaluate and plot if specified by arguments."""
if world_size > 1: if world_size > 1:
self._setup_ddp(world_size) self._setup_ddp(world_size)
@ -392,6 +395,7 @@ class BaseTrainer:
self.run_callbacks('teardown') self.run_callbacks('teardown')
def save_model(self): def save_model(self):
"""Save model checkpoints based on various conditions."""
ckpt = { ckpt = {
'epoch': self.epoch, 'epoch': self.epoch,
'best_fitness': self.best_fitness, 'best_fitness': self.best_fitness,
@ -436,6 +440,7 @@ class BaseTrainer:
return ckpt return ckpt
def optimizer_step(self): def optimizer_step(self):
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
self.scaler.unscale_(self.optimizer) # unscale gradients self.scaler.unscale_(self.optimizer) # unscale gradients
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
self.scaler.step(self.optimizer) self.scaler.step(self.optimizer)
@ -461,9 +466,11 @@ class BaseTrainer:
return metrics, fitness return metrics, fitness
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
"""Get model and raise NotImplementedError for loading cfg files."""
raise NotImplementedError("This task trainer doesn't support loading cfg files") raise NotImplementedError("This task trainer doesn't support loading cfg files")
def get_validator(self): def get_validator(self):
"""Returns a NotImplementedError when the get_validator function is called."""
raise NotImplementedError('get_validator function not implemented in trainer') raise NotImplementedError('get_validator function not implemented in trainer')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
@ -492,19 +499,24 @@ class BaseTrainer:
self.model.names = self.data['names'] self.model.names = self.data['names']
def build_targets(self, preds, targets): def build_targets(self, preds, targets):
"""Builds target tensors for training YOLO model."""
pass pass
def progress_string(self): def progress_string(self):
"""Returns a string describing training progress."""
return '' return ''
# TODO: may need to put these following functions into callback # TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Plots training samples during YOLOv5 training."""
pass pass
def plot_training_labels(self): def plot_training_labels(self):
"""Plots training labels for YOLO model."""
pass pass
def save_metrics(self, metrics): def save_metrics(self, metrics):
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values()) keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
@ -512,9 +524,11 @@ class BaseTrainer:
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n') f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
def plot_metrics(self): def plot_metrics(self):
"""Plot and display metrics visually."""
pass pass
def final_eval(self): def final_eval(self):
"""Performs final evaluation and validation for object detection YOLO model."""
for f in self.last, self.best: for f in self.last, self.best:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
@ -525,6 +539,7 @@ class BaseTrainer:
self.run_callbacks('on_fit_epoch_end') self.run_callbacks('on_fit_epoch_end')
def check_resume(self): def check_resume(self):
"""Check if resume checkpoint exists and update arguments accordingly."""
resume = self.args.resume resume = self.args.resume
if resume: if resume:
try: try:
@ -539,6 +554,7 @@ class BaseTrainer:
self.resume = resume self.resume = resume
def resume_training(self, ckpt): def resume_training(self, ckpt):
"""Resume YOLO training from given epoch and best fitness."""
if ckpt is None: if ckpt is None:
return return
best_fitness = 0.0 best_fitness = 0.0

@ -195,58 +195,72 @@ class BaseValidator:
return stats return stats
def add_callback(self, event: str, callback): def add_callback(self, event: str, callback):
""" """Appends the given callback."""
Appends the given callback.
"""
self.callbacks[event].append(callback) self.callbacks[event].append(callback)
def run_callbacks(self, event: str): def run_callbacks(self, event: str):
"""Runs all callbacks associated with a specified event."""
for callback in self.callbacks.get(event, []): for callback in self.callbacks.get(event, []):
callback(self) callback(self)
def get_dataloader(self, dataset_path, batch_size): def get_dataloader(self, dataset_path, batch_size):
"""Get data loader from dataset path and batch size."""
raise NotImplementedError('get_dataloader function not implemented for this validator') raise NotImplementedError('get_dataloader function not implemented for this validator')
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses an input batch."""
return batch return batch
def postprocess(self, preds): def postprocess(self, preds):
"""Describes and summarizes the purpose of 'postprocess()' but no details mentioned."""
return preds return preds
def init_metrics(self, model): def init_metrics(self, model):
"""Initialize performance metrics for the YOLO model."""
pass pass
def update_metrics(self, preds, batch): def update_metrics(self, preds, batch):
"""Updates metrics based on predictions and batch."""
pass pass
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
"""Finalizes and returns all metrics."""
pass pass
def get_stats(self): def get_stats(self):
"""Returns statistics about the model's performance."""
return {} return {}
def check_stats(self, stats): def check_stats(self, stats):
"""Checks statistics."""
pass pass
def print_results(self): def print_results(self):
"""Prints the results of the model's predictions."""
pass pass
def get_desc(self): def get_desc(self):
"""Get description of the YOLO model."""
pass pass
@property @property
def metric_keys(self): def metric_keys(self):
"""Returns the metric keys used in YOLO training/validation."""
return [] return []
# TODO: may need to put these following functions into callback # TODO: may need to put these following functions into callback
def plot_val_samples(self, batch, ni): def plot_val_samples(self, batch, ni):
"""Plots validation samples during training."""
pass pass
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots YOLO model predictions on batch images."""
pass pass
def pred_to_json(self, preds, batch): def pred_to_json(self, preds, batch):
"""Convert predictions to JSON format."""
pass pass
def eval_json(self, stats): def eval_json(self, stats):
"""Evaluate and return JSON format of prediction statistics."""
pass pass

@ -182,8 +182,10 @@ def plt_settings(rcparams={'font.size': 11}, backend='Agg'):
""" """
def decorator(func): def decorator(func):
"""Decorator to apply temporary rc parameters and backend to a function."""
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
"""Sets rc parameters and backend, calls the original function, and restores the settings."""
original_backend = plt.get_backend() original_backend = plt.get_backend()
plt.switch_backend(backend) plt.switch_backend(backend)
@ -229,6 +231,7 @@ class EmojiFilter(logging.Filter):
""" """
def filter(self, record): def filter(self, record):
"""Filter logs by emoji unicode characters on windows."""
record.msg = emojis(record.msg) record.msg = emojis(record.msg)
return super().filter(record) return super().filter(record)
@ -573,13 +576,16 @@ class TryExcept(contextlib.ContextDecorator):
"""YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager.""" """YOLOv8 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager."""
def __init__(self, msg='', verbose=True): def __init__(self, msg='', verbose=True):
"""Initialize TryExcept class with optional message and verbosity settings."""
self.msg = msg self.msg = msg
self.verbose = verbose self.verbose = verbose
def __enter__(self): def __enter__(self):
"""Executes when entering TryExcept context, initializes instance."""
pass pass
def __exit__(self, exc_type, value, traceback): def __exit__(self, exc_type, value, traceback):
"""Defines behavior when exiting a 'with' block, prints error message if necessary."""
if self.verbose and value: if self.verbose and value:
print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}")) print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
return True return True
@ -589,6 +595,7 @@ def threaded(func):
"""Multi-threads a target function and returns thread. Usage: @threaded decorator.""" """Multi-threads a target function and returns thread. Usage: @threaded decorator."""
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
"""Multi-threads a given function and returns the thread."""
thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
thread.start() thread.start()
return thread return thread
@ -602,6 +609,7 @@ def set_sentry():
""" """
def before_send(event, hint): def before_send(event, hint):
"""A function executed before sending the event to Sentry."""
if 'exc_info' in hint: if 'exc_info' in hint:
exc_type, exc_value, tb = hint['exc_info'] exc_type, exc_value, tb = hint['exc_info']
if exc_type in (KeyboardInterrupt, FileNotFoundError) \ if exc_type in (KeyboardInterrupt, FileNotFoundError) \
@ -698,6 +706,7 @@ def set_settings(kwargs, file=SETTINGS_YAML):
def deprecation_warn(arg, new_arg, version=None): def deprecation_warn(arg, new_arg, version=None):
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
if not version: if not version:
version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release version = float(__version__[:3]) + 0.2 # deprecate after 2nd major release
LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. " LOGGER.warning(f"WARNING ⚠️ '{arg}' is deprecated and will be removed in 'ultralytics {version}' in the future. "

@ -35,7 +35,30 @@ from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.torch_utils import select_device from ultralytics.yolo.utils.torch_utils import select_device
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=False): def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
imgsz=160,
half=False,
int8=False,
device='cpu',
hard_fail=False):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
Args:
model (Union[str, Path], optional): Path to the model file or directory. Default is
Path(SETTINGS['weights_dir']) / 'yolov8n.pt'.
imgsz (int, optional): Image size for the benchmark. Default is 160.
half (bool, optional): Use half-precision for the model if True. Default is False.
int8 (bool, optional): Use int8-precision for the model if True. Default is False.
device (str, optional): Device to run the benchmark on, either 'cpu' or 'cuda'. Default is 'cpu'.
hard_fail (Union[bool, float], optional): If True or a float, assert benchmarks pass with given metric.
Default is False.
Returns:
df (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size,
metric, and inference time.
"""
import pandas as pd import pandas as pd
pd.options.display.max_columns = 10 pd.options.display.max_columns = 10
pd.options.display.width = 120 pd.options.display.width = 120
@ -61,7 +84,7 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
filename = model.ckpt_path or model.cfg filename = model.ckpt_path or model.cfg
export = model # PyTorch format export = model # PyTorch format
else: else:
filename = model.export(imgsz=imgsz, format=format, half=half, device=device) # all others filename = model.export(imgsz=imgsz, format=format, half=half, int8=int8, device=device) # all others
export = YOLO(filename, task=model.task) export = YOLO(filename, task=model.task)
assert suffix in str(filename), 'export failed' assert suffix in str(filename), 'export failed'
emoji = '' # indicates export succeeded emoji = '' # indicates export succeeded
@ -83,7 +106,14 @@ def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, hal
elif model.task == 'pose': elif model.task == 'pose':
data, key = 'coco8-pose.yaml', 'metrics/mAP50-95(P)' data, key = 'coco8-pose.yaml', 'metrics/mAP50-95(P)'
results = export.val(data=data, batch=1, imgsz=imgsz, plots=False, device=device, half=half, verbose=False) results = export.val(data=data,
batch=1,
imgsz=imgsz,
plots=False,
device=device,
half=half,
int8=int8,
verbose=False)
metric, speed = results.results_dict[key], results.speed['inference'] metric, speed = results.results_dict[key], results.speed['inference']
y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)]) y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e: except Exception as e:

@ -2,111 +2,144 @@
""" """
Base callbacks Base callbacks
""" """
from collections import defaultdict from collections import defaultdict
from copy import deepcopy from copy import deepcopy
# Trainer callbacks ---------------------------------------------------------------------------------------------------- # Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
"""Called before the pretraining routine starts."""
pass pass
def on_pretrain_routine_end(trainer): def on_pretrain_routine_end(trainer):
"""Called after the pretraining routine ends."""
pass pass
def on_train_start(trainer): def on_train_start(trainer):
"""Called when the training starts."""
pass pass
def on_train_epoch_start(trainer): def on_train_epoch_start(trainer):
"""Called at the start of each training epoch."""
pass pass
def on_train_batch_start(trainer): def on_train_batch_start(trainer):
"""Called at the start of each training batch."""
pass pass
def optimizer_step(trainer): def optimizer_step(trainer):
"""Called when the optimizer takes a step."""
pass pass
def on_before_zero_grad(trainer): def on_before_zero_grad(trainer):
"""Called before the gradients are set to zero."""
pass pass
def on_train_batch_end(trainer): def on_train_batch_end(trainer):
"""Called at the end of each training batch."""
pass pass
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
"""Called at the end of each training epoch."""
pass pass
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Called at the end of each fit epoch (train + val)."""
pass pass
def on_model_save(trainer): def on_model_save(trainer):
"""Called when the model is saved."""
pass pass
def on_train_end(trainer): def on_train_end(trainer):
"""Called when the training ends."""
pass pass
def on_params_update(trainer): def on_params_update(trainer):
"""Called when the model parameters are updated."""
pass pass
def teardown(trainer): def teardown(trainer):
"""Called during the teardown of the training process."""
pass pass
# Validator callbacks -------------------------------------------------------------------------------------------------- # Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator): def on_val_start(validator):
"""Called when the validation starts."""
pass pass
def on_val_batch_start(validator): def on_val_batch_start(validator):
"""Called at the start of each validation batch."""
pass pass
def on_val_batch_end(validator): def on_val_batch_end(validator):
"""Called at the end of each validation batch."""
pass pass
def on_val_end(validator): def on_val_end(validator):
"""Called when the validation ends."""
pass pass
# Predictor callbacks -------------------------------------------------------------------------------------------------- # Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor): def on_predict_start(predictor):
"""Called when the prediction starts."""
pass pass
def on_predict_batch_start(predictor): def on_predict_batch_start(predictor):
"""Called at the start of each prediction batch."""
pass pass
def on_predict_batch_end(predictor): def on_predict_batch_end(predictor):
"""Called at the end of each prediction batch."""
pass pass
def on_predict_postprocess_end(predictor): def on_predict_postprocess_end(predictor):
"""Called after the post-processing of the prediction ends."""
pass pass
def on_predict_end(predictor): def on_predict_end(predictor):
"""Called when the prediction ends."""
pass pass
# Exporter callbacks --------------------------------------------------------------------------------------------------- # Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter): def on_export_start(exporter):
"""Called when the model export starts."""
pass pass
def on_export_end(exporter): def on_export_end(exporter):
"""Called when the model export ends."""
pass pass
@ -146,10 +179,23 @@ default_callbacks = {
def get_default_callbacks(): def get_default_callbacks():
"""
Return a copy of the default_callbacks dictionary with lists as default values.
Returns:
(defaultdict): A defaultdict with keys from default_callbacks and empty lists as default values.
"""
return defaultdict(list, deepcopy(default_callbacks)) return defaultdict(list, deepcopy(default_callbacks))
def add_integration_callbacks(instance): def add_integration_callbacks(instance):
"""
Add integration callbacks from various sources to the instance's callbacks.
Args:
instance (Trainer, Predictor, Validator, Exporter): An object with a 'callbacks' attribute that is a dictionary
of callback lists.
"""
from .clearml import callbacks as clearml_callbacks from .clearml import callbacks as clearml_callbacks
from .comet import callbacks as comet_callbacks from .comet import callbacks as comet_callbacks
from .hub import callbacks as hub_callbacks from .hub import callbacks as hub_callbacks

@ -59,6 +59,7 @@ def _log_plot(title, plot_path) -> None:
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
"""Runs at start of pretraining routine; initializes and connects/ logs task to ClearML."""
try: try:
task = Task.current_task() task = Task.current_task()
if task: if task:
@ -83,11 +84,13 @@ def on_pretrain_routine_start(trainer):
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
"""Logs debug samples for the first epoch of YOLO training."""
if trainer.epoch == 1 and Task.current_task(): if trainer.epoch == 1 and Task.current_task():
_log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic') _log_debug_samples(sorted(trainer.save_dir.glob('train_batch*.jpg')), 'Mosaic')
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Reports model information to logger at the end of an epoch."""
task = Task.current_task() task = Task.current_task()
if task: if task:
# You should have access to the validation bboxes under jdict # You should have access to the validation bboxes under jdict
@ -105,12 +108,14 @@ def on_fit_epoch_end(trainer):
def on_val_end(validator): def on_val_end(validator):
"""Logs validation results including labels and predictions."""
if Task.current_task(): if Task.current_task():
# Log val_labels and val_pred # Log val_labels and val_pred
_log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation') _log_debug_samples(sorted(validator.save_dir.glob('val*.jpg')), 'Validation')
def on_train_end(trainer): def on_train_end(trainer):
"""Logs final model and its name on training completion."""
task = Task.current_task() task = Task.current_task()
if task: if task:
# Log final results, CM matrix + PR plots # Log final results, CM matrix + PR plots

@ -36,6 +36,7 @@ _comet_image_prediction_count = 0
def _get_experiment_type(mode, project_name): def _get_experiment_type(mode, project_name):
"""Return an experiment based on mode and project name."""
if mode == 'offline': if mode == 'offline':
return comet_ml.OfflineExperiment(project_name=project_name) return comet_ml.OfflineExperiment(project_name=project_name)
@ -61,6 +62,7 @@ def _create_experiment(args):
def _fetch_trainer_metadata(trainer): def _fetch_trainer_metadata(trainer):
"""Returns metadata for YOLO training including epoch and asset saving status."""
curr_epoch = trainer.epoch + 1 curr_epoch = trainer.epoch + 1
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
@ -97,6 +99,7 @@ def _scale_bounding_box_to_original_image_shape(box, resized_image_shape, origin
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None): def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None):
"""Format ground truth annotations for detection."""
indices = batch['batch_idx'] == img_idx indices = batch['batch_idx'] == img_idx
bboxes = batch['bboxes'][indices] bboxes = batch['bboxes'][indices]
if len(bboxes) == 0: if len(bboxes) == 0:
@ -120,6 +123,7 @@ def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, c
def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None): def _format_prediction_annotations_for_detection(image_path, metadata, class_label_map=None):
"""Format YOLO predictions for object detection visualization."""
stem = image_path.stem stem = image_path.stem
image_id = int(stem) if stem.isnumeric() else stem image_id = int(stem) if stem.isnumeric() else stem
@ -142,6 +146,7 @@ def _format_prediction_annotations_for_detection(image_path, metadata, class_lab
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map): def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map):
"""Join the ground truth and prediction annotations if they exist."""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, ground_truth_annotations = _format_ground_truth_annotations_for_detection(img_idx, image_path, batch,
class_label_map) class_label_map)
prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map, prediction_annotations = _format_prediction_annotations_for_detection(image_path, prediction_metadata_map,
@ -153,6 +158,7 @@ def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, clas
def _create_prediction_metadata_map(model_predictions): def _create_prediction_metadata_map(model_predictions):
"""Create metadata map for model predictions by groupings them based on image ID."""
pred_metadata_map = {} pred_metadata_map = {}
for prediction in model_predictions: for prediction in model_predictions:
pred_metadata_map.setdefault(prediction['image_id'], []) pred_metadata_map.setdefault(prediction['image_id'], [])
@ -162,6 +168,7 @@ def _create_prediction_metadata_map(model_predictions):
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch): def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
"""Log the confusion matrix to Weights and Biases experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data['names'].values()) + ['background'] names = list(trainer.data['names'].values()) + ['background']
experiment.log_confusion_matrix( experiment.log_confusion_matrix(
@ -174,6 +181,7 @@ def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch):
def _log_images(experiment, image_paths, curr_step, annotations=None): def _log_images(experiment, image_paths, curr_step, annotations=None):
"""Logs images to the experiment with optional annotations."""
if annotations: if annotations:
for image_path, annotation in zip(image_paths, annotations): for image_path, annotation in zip(image_paths, annotations):
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation) experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
@ -184,6 +192,7 @@ def _log_images(experiment, image_paths, curr_step, annotations=None):
def _log_image_predictions(experiment, validator, curr_step): def _log_image_predictions(experiment, validator, curr_step):
"""Logs predicted boxes for a single image during training."""
global _comet_image_prediction_count global _comet_image_prediction_count
task = validator.args.task task = validator.args.task
@ -225,6 +234,7 @@ def _log_image_predictions(experiment, validator, curr_step):
def _log_plots(experiment, trainer): def _log_plots(experiment, trainer):
"""Logs evaluation plots and label plots for the experiment."""
plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES] plot_filenames = [trainer.save_dir / f'{plots}.png' for plots in EVALUATION_PLOT_NAMES]
_log_images(experiment, plot_filenames, None) _log_images(experiment, plot_filenames, None)
@ -233,6 +243,7 @@ def _log_plots(experiment, trainer):
def _log_model(experiment, trainer): def _log_model(experiment, trainer):
"""Log the best-trained model to Comet.ml."""
experiment.log_model( experiment.log_model(
COMET_MODEL_NAME, COMET_MODEL_NAME,
file_or_folder=str(trainer.best), file_or_folder=str(trainer.best),
@ -242,12 +253,14 @@ def _log_model(experiment, trainer):
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
"""Creates or resumes a CometML experiment at the start of a YOLO pre-training routine."""
experiment = comet_ml.get_global_experiment() experiment = comet_ml.get_global_experiment()
if not experiment: if not experiment:
_create_experiment(trainer.args) _create_experiment(trainer.args)
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
"""Log metrics and save batch images at the end of training epochs."""
experiment = comet_ml.get_global_experiment() experiment = comet_ml.get_global_experiment()
if not experiment: if not experiment:
return return
@ -267,6 +280,7 @@ def on_train_epoch_end(trainer):
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Logs model assets at the end of each epoch."""
experiment = comet_ml.get_global_experiment() experiment = comet_ml.get_global_experiment()
if not experiment: if not experiment:
return return
@ -296,6 +310,7 @@ def on_fit_epoch_end(trainer):
def on_train_end(trainer): def on_train_end(trainer):
"""Perform operations at the end of training."""
experiment = comet_ml.get_global_experiment() experiment = comet_ml.get_global_experiment()
if not experiment: if not experiment:
return return

@ -9,6 +9,7 @@ from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
def on_pretrain_routine_end(trainer): def on_pretrain_routine_end(trainer):
"""Logs info before starting timer for upload rate limit."""
session = getattr(trainer, 'hub_session', None) session = getattr(trainer, 'hub_session', None)
if session: if session:
# Start timer for upload rate limit # Start timer for upload rate limit
@ -17,6 +18,7 @@ def on_pretrain_routine_end(trainer):
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Uploads training progress metrics at the end of each epoch."""
session = getattr(trainer, 'hub_session', None) session = getattr(trainer, 'hub_session', None)
if session: if session:
# Upload metrics after val end # Upload metrics after val end
@ -35,6 +37,7 @@ def on_fit_epoch_end(trainer):
def on_model_save(trainer): def on_model_save(trainer):
"""Saves checkpoints to Ultralytics HUB with rate limiting."""
session = getattr(trainer, 'hub_session', None) session = getattr(trainer, 'hub_session', None)
if session: if session:
# Upload checkpoints with rate limiting # Upload checkpoints with rate limiting
@ -46,6 +49,7 @@ def on_model_save(trainer):
def on_train_end(trainer): def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
session = getattr(trainer, 'hub_session', None) session = getattr(trainer, 'hub_session', None)
if session: if session:
# Upload final model and metrics with exponential standoff # Upload final model and metrics with exponential standoff
@ -57,18 +61,22 @@ def on_train_end(trainer):
def on_train_start(trainer): def on_train_start(trainer):
"""Run traces on train start."""
traces(trainer.args, traces_sample_rate=1.0) traces(trainer.args, traces_sample_rate=1.0)
def on_val_start(validator): def on_val_start(validator):
"""Runs traces on validation start."""
traces(validator.args, traces_sample_rate=1.0) traces(validator.args, traces_sample_rate=1.0)
def on_predict_start(predictor): def on_predict_start(predictor):
"""Run traces on predict start."""
traces(predictor.args, traces_sample_rate=1.0) traces(predictor.args, traces_sample_rate=1.0)
def on_export_start(exporter): def on_export_start(exporter):
"""Run traces on export start."""
traces(exporter.args, traces_sample_rate=1.0) traces(exporter.args, traces_sample_rate=1.0)

@ -16,6 +16,7 @@ except (ImportError, AssertionError):
def on_pretrain_routine_end(trainer): def on_pretrain_routine_end(trainer):
"""Logs training parameters to MLflow."""
global mlflow, run, run_id, experiment_name global mlflow, run, run_id, experiment_name
if os.environ.get('MLFLOW_TRACKING_URI') is None: if os.environ.get('MLFLOW_TRACKING_URI') is None:
@ -45,17 +46,20 @@ def on_pretrain_routine_end(trainer):
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Logs training metrics to Mlflow."""
if mlflow: if mlflow:
metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()} metrics_dict = {f"{re.sub('[()]', '', k)}": float(v) for k, v in trainer.metrics.items()}
run.log_metrics(metrics=metrics_dict, step=trainer.epoch) run.log_metrics(metrics=metrics_dict, step=trainer.epoch)
def on_model_save(trainer): def on_model_save(trainer):
"""Logs model and metrics to mlflow on save."""
if mlflow: if mlflow:
run.log_artifact(trainer.last) run.log_artifact(trainer.last)
def on_train_end(trainer): def on_train_end(trainer):
"""Called at end of train loop to log model artifact info."""
if mlflow: if mlflow:
root_dir = Path(__file__).resolve().parents[3] root_dir = Path(__file__).resolve().parents[3]
run.log_artifact(trainer.best) run.log_artifact(trainer.best)

@ -7,6 +7,7 @@ except (ImportError, AssertionError):
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Sends training metrics to Ray Tune at end of each epoch."""
if ray.tune.is_session_enabled(): if ray.tune.is_session_enabled():
metrics = trainer.metrics metrics = trainer.metrics
metrics['epoch'] = trainer.epoch metrics['epoch'] = trainer.epoch

@ -12,12 +12,14 @@ writer = None # TensorBoard SummaryWriter instance
def _log_scalars(scalars, step=0): def _log_scalars(scalars, step=0):
"""Logs scalar values to TensorBoard."""
if writer: if writer:
for k, v in scalars.items(): for k, v in scalars.items():
writer.add_scalar(k, v, step) writer.add_scalar(k, v, step)
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
"""Initialize TensorBoard logging with SummaryWriter."""
if SummaryWriter: if SummaryWriter:
try: try:
global writer global writer
@ -29,10 +31,12 @@ def on_pretrain_routine_start(trainer):
def on_batch_end(trainer): def on_batch_end(trainer):
"""Logs scalar statistics at the end of a training batch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1) _log_scalars(trainer.label_loss_items(trainer.tloss, prefix='train'), trainer.epoch + 1)
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Logs epoch metrics at end of training epoch."""
_log_scalars(trainer.metrics, trainer.epoch + 1) _log_scalars(trainer.metrics, trainer.epoch + 1)

@ -11,11 +11,13 @@ except (ImportError, AssertionError):
def on_pretrain_routine_start(trainer): def on_pretrain_routine_start(trainer):
"""Initiate and start project if module is present."""
wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars( wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars(
trainer.args)) if not wb.run else wb.run trainer.args)) if not wb.run else wb.run
def on_fit_epoch_end(trainer): def on_fit_epoch_end(trainer):
"""Logs training metrics and model information at the end of an epoch."""
wb.run.log(trainer.metrics, step=trainer.epoch + 1) wb.run.log(trainer.metrics, step=trainer.epoch + 1)
if trainer.epoch == 0: if trainer.epoch == 0:
model_info = { model_info = {
@ -26,6 +28,7 @@ def on_fit_epoch_end(trainer):
def on_train_epoch_end(trainer): def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1) wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1: if trainer.epoch == 1:
@ -35,6 +38,7 @@ def on_train_epoch_end(trainer):
def on_train_end(trainer): def on_train_end(trainer):
"""Save the best model as an artifact at end of training."""
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model')
if trainer.best.exists(): if trainer.best.exists():
art.add_file(trainer.best) art.add_file(trainer.best)

@ -295,7 +295,7 @@ def check_file(file, suffix='', download=True, hard=True):
def check_yaml(file, suffix=('.yaml', '.yml'), hard=True): def check_yaml(file, suffix=('.yaml', '.yml'), hard=True):
# Search/download YAML file (if necessary) and return path, checking suffix """Search/download YAML file (if necessary) and return path, checking suffix."""
return check_file(file, suffix, hard=hard) return check_file(file, suffix, hard=hard)
@ -315,6 +315,7 @@ def check_imshow(warn=False):
def check_yolo(verbose=True, device=''): def check_yolo(verbose=True, device=''):
"""Return a human-readable YOLO software and hardware summary."""
from ultralytics.yolo.utils.torch_utils import select_device from ultralytics.yolo.utils.torch_utils import select_device
if is_colab(): if is_colab():

@ -24,6 +24,7 @@ def find_free_network_port() -> int:
def generate_ddp_file(trainer): def generate_ddp_file(trainer):
"""Generates a DDP file and returns its file name."""
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1) module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__": content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
@ -43,6 +44,7 @@ def generate_ddp_file(trainer):
def generate_ddp_command(world_size, trainer): def generate_ddp_command(world_size, trainer):
"""Generates and returns command for distributed training."""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
if not trainer.resume: if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir shutil.rmtree(trainer.save_dir) # remove the save_dir

@ -192,7 +192,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3): def download(url, dir=Path.cwd(), unzip=True, delete=False, curl=False, threads=1, retry=3):
# Multithreaded file download and unzip function, used in data.yaml for autodownload """Downloads and unzips files concurrently if threads > 1, else sequentially."""
dir = Path(dir) dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1: if threads > 1:

@ -6,4 +6,5 @@ from ultralytics.yolo.utils import emojis
class HUBModelError(Exception): class HUBModelError(Exception):
def __init__(self, message='Model not found. Please check model URL and try again.'): def __init__(self, message='Model not found. Please check model URL and try again.'):
"""Create an exception for when a model is not found."""
super().__init__(emojis(message)) super().__init__(emojis(message))

@ -11,13 +11,16 @@ class WorkingDirectory(contextlib.ContextDecorator):
"""Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.""" """Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager."""
def __init__(self, new_dir): def __init__(self, new_dir):
"""Sets the working directory to 'new_dir' upon instantiation."""
self.dir = new_dir # new dir self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir self.cwd = Path.cwd().resolve() # current dir
def __enter__(self): def __enter__(self):
"""Changes the current directory to the specified directory."""
os.chdir(self.dir) os.chdir(self.dir)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Restore the current working directory on context exit."""
os.chdir(self.cwd) os.chdir(self.cwd)

@ -14,6 +14,7 @@ def _ntuple(n):
"""From PyTorch internals.""" """From PyTorch internals."""
def parse(x): def parse(x):
"""Parse bounding boxes format between XYWH and LTWH."""
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n)) return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
return parse return parse
@ -64,6 +65,7 @@ class Bboxes:
# return Bboxes(bboxes, format) # return Bboxes(bboxes, format)
def convert(self, format): def convert(self, format):
"""Converts bounding box format from one type to another."""
assert format in _formats assert format in _formats
if self.format == format: if self.format == format:
return return
@ -77,6 +79,7 @@ class Bboxes:
self.format = format self.format = format
def areas(self): def areas(self):
"""Return box areas."""
self.convert('xyxy') self.convert('xyxy')
return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) return (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1])
@ -125,6 +128,7 @@ class Bboxes:
self.bboxes[:, 3] += offset[3] self.bboxes[:, 3] += offset[3]
def __len__(self): def __len__(self):
"""Return the number of boxes."""
return len(self.bboxes) return len(self.bboxes)
@classmethod @classmethod
@ -202,9 +206,11 @@ class Instances:
self.segments = segments self.segments = segments
def convert_bbox(self, format): def convert_bbox(self, format):
"""Convert bounding box format."""
self._bboxes.convert(format=format) self._bboxes.convert(format=format)
def bbox_areas(self): def bbox_areas(self):
"""Calculate the area of bounding boxes."""
self._bboxes.areas() self._bboxes.areas()
def scale(self, scale_w, scale_h, bbox_only=False): def scale(self, scale_w, scale_h, bbox_only=False):
@ -219,6 +225,7 @@ class Instances:
self.keypoints[..., 1] *= scale_h self.keypoints[..., 1] *= scale_h
def denormalize(self, w, h): def denormalize(self, w, h):
"""Denormalizes boxes, segments, and keypoints from normalized coordinates."""
if not self.normalized: if not self.normalized:
return return
self._bboxes.mul(scale=(w, h, w, h)) self._bboxes.mul(scale=(w, h, w, h))
@ -230,6 +237,7 @@ class Instances:
self.normalized = False self.normalized = False
def normalize(self, w, h): def normalize(self, w, h):
"""Normalize bounding boxes, segments, and keypoints to image dimensions."""
if self.normalized: if self.normalized:
return return
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h)) self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
@ -279,6 +287,7 @@ class Instances:
) )
def flipud(self, h): def flipud(self, h):
"""Flips the coordinates of bounding boxes, segments, and keypoints vertically."""
if self._bboxes.format == 'xyxy': if self._bboxes.format == 'xyxy':
y1 = self.bboxes[:, 1].copy() y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy() y2 = self.bboxes[:, 3].copy()
@ -291,6 +300,7 @@ class Instances:
self.keypoints[..., 1] = h - self.keypoints[..., 1] self.keypoints[..., 1] = h - self.keypoints[..., 1]
def fliplr(self, w): def fliplr(self, w):
"""Reverses the order of the bounding boxes and segments horizontally."""
if self._bboxes.format == 'xyxy': if self._bboxes.format == 'xyxy':
x1 = self.bboxes[:, 0].copy() x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy() x2 = self.bboxes[:, 2].copy()
@ -303,6 +313,7 @@ class Instances:
self.keypoints[..., 0] = w - self.keypoints[..., 0] self.keypoints[..., 0] = w - self.keypoints[..., 0]
def clip(self, w, h): def clip(self, w, h):
"""Clips bounding boxes, segments, and keypoints values to stay within image boundaries."""
ori_format = self._bboxes.format ori_format = self._bboxes.format
self.convert_bbox(format='xyxy') self.convert_bbox(format='xyxy')
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
@ -316,6 +327,7 @@ class Instances:
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def update(self, bboxes, segments=None, keypoints=None): def update(self, bboxes, segments=None, keypoints=None):
"""Updates instance variables."""
new_bboxes = Bboxes(bboxes, format=self._bboxes.format) new_bboxes = Bboxes(bboxes, format=self._bboxes.format)
self._bboxes = new_bboxes self._bboxes = new_bboxes
if segments is not None: if segments is not None:
@ -324,6 +336,7 @@ class Instances:
self.keypoints = keypoints self.keypoints = keypoints
def __len__(self): def __len__(self):
"""Return the length of the instance list."""
return len(self.bboxes) return len(self.bboxes)
@classmethod @classmethod
@ -363,4 +376,5 @@ class Instances:
@property @property
def bboxes(self): def bboxes(self):
"""Return bounding boxes."""
return self._bboxes.bboxes return self._bboxes.bboxes

@ -12,9 +12,11 @@ class VarifocalLoss(nn.Module):
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367.""" """Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
def __init__(self): def __init__(self):
"""Initialize the VarifocalLoss class."""
super().__init__() super().__init__()
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0): def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') * loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
@ -25,6 +27,7 @@ class VarifocalLoss(nn.Module):
class BboxLoss(nn.Module): class BboxLoss(nn.Module):
def __init__(self, reg_max, use_dfl=False): def __init__(self, reg_max, use_dfl=False):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__() super().__init__()
self.reg_max = reg_max self.reg_max = reg_max
self.use_dfl = use_dfl self.use_dfl = use_dfl
@ -64,6 +67,7 @@ class KeypointLoss(nn.Module):
self.sigmas = sigmas self.sigmas = sigmas
def forward(self, pred_kpts, gt_kpts, kpt_mask, area): def forward(self, pred_kpts, gt_kpts, kpt_mask, area):
"""Calculates keypoint loss factor and Euclidean distance loss for predicted and actual keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2 d = (pred_kpts[..., 0] - gt_kpts[..., 0]) ** 2 + (pred_kpts[..., 1] - gt_kpts[..., 1]) ** 2
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9) kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0)) / (torch.sum(kpt_mask != 0) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula

@ -180,6 +180,7 @@ class FocalLoss(nn.Module):
"""Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).""" """Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)."""
def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
"""Initialize FocalLoss object with given loss function and hyperparameters."""
super().__init__() super().__init__()
self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
self.gamma = gamma self.gamma = gamma
@ -188,6 +189,7 @@ class FocalLoss(nn.Module):
self.loss_fcn.reduction = 'none' # required to apply FL to each element self.loss_fcn.reduction = 'none' # required to apply FL to each element
def forward(self, pred, true): def forward(self, pred, true):
"""Calculates and updates confusion matrix for object detection/classification tasks."""
loss = self.loss_fcn(pred, true) loss = self.loss_fcn(pred, true)
# p_t = torch.exp(-loss) # p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
@ -220,6 +222,7 @@ class ConfusionMatrix:
""" """
def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'): def __init__(self, nc, conf=0.25, iou_thres=0.45, task='detect'):
"""Initialize attributes for the YOLO model."""
self.task = task self.task = task
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc)) self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
self.nc = nc # number of classes self.nc = nc # number of classes
@ -285,9 +288,11 @@ class ConfusionMatrix:
self.matrix[dc, self.nc] += 1 # predicted background self.matrix[dc, self.nc] += 1 # predicted background
def matrix(self): def matrix(self):
"""Returns the confusion matrix."""
return self.matrix return self.matrix
def tp_fp(self): def tp_fp(self):
"""Returns true positives and false positives."""
tp = self.matrix.diagonal() # true positives tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections) # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
@ -679,6 +684,7 @@ class DetMetrics(SimpleClass):
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def process(self, tp, conf, pred_cls, target_cls): def process(self, tp, conf, pred_cls, target_cls):
"""Process predicted results for object detection and update metrics."""
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir, results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:] names=self.names)[2:]
self.box.nc = len(self.names) self.box.nc = len(self.names)
@ -686,28 +692,35 @@ class DetMetrics(SimpleClass):
@property @property
def keys(self): def keys(self):
"""Returns a list of keys for accessing specific metrics."""
return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)'] return ['metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)']
def mean_results(self): def mean_results(self):
"""Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95."""
return self.box.mean_results() return self.box.mean_results()
def class_result(self, i): def class_result(self, i):
"""Return the result of evaluating the performance of an object detection model on a specific class."""
return self.box.class_result(i) return self.box.class_result(i)
@property @property
def maps(self): def maps(self):
"""Returns mean Average Precision (mAP) scores per class."""
return self.box.maps return self.box.maps
@property @property
def fitness(self): def fitness(self):
"""Returns the fitness of box object."""
return self.box.fitness() return self.box.fitness()
@property @property
def ap_class_index(self): def ap_class_index(self):
"""Returns the average precision index per class."""
return self.box.ap_class_index return self.box.ap_class_index
@property @property
def results_dict(self): def results_dict(self):
"""Returns dictionary of computed performance metrics and statistics."""
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
@ -781,22 +794,27 @@ class SegmentMetrics(SimpleClass):
@property @property
def keys(self): def keys(self):
"""Returns a list of keys for accessing metrics."""
return [ return [
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)'] 'metrics/precision(M)', 'metrics/recall(M)', 'metrics/mAP50(M)', 'metrics/mAP50-95(M)']
def mean_results(self): def mean_results(self):
"""Return the mean metrics for bounding box and segmentation results."""
return self.box.mean_results() + self.seg.mean_results() return self.box.mean_results() + self.seg.mean_results()
def class_result(self, i): def class_result(self, i):
"""Returns classification results for a specified class index."""
return self.box.class_result(i) + self.seg.class_result(i) return self.box.class_result(i) + self.seg.class_result(i)
@property @property
def maps(self): def maps(self):
"""Returns mAP scores for object detection and semantic segmentation models."""
return self.box.maps + self.seg.maps return self.box.maps + self.seg.maps
@property @property
def fitness(self): def fitness(self):
"""Get the fitness score for both segmentation and bounding box models."""
return self.seg.fitness() + self.box.fitness() return self.seg.fitness() + self.box.fitness()
@property @property
@ -806,6 +824,7 @@ class SegmentMetrics(SimpleClass):
@property @property
def results_dict(self): def results_dict(self):
"""Returns results of object detection model for evaluation."""
return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness])) return dict(zip(self.keys + ['fitness'], self.mean_results() + [self.fitness]))
@ -846,6 +865,7 @@ class PoseMetrics(SegmentMetrics):
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr): def __getattr__(self, attr):
"""Raises an AttributeError if an invalid attribute is accessed."""
name = self.__class__.__name__ name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@ -884,22 +904,27 @@ class PoseMetrics(SegmentMetrics):
@property @property
def keys(self): def keys(self):
"""Returns list of evaluation metric keys."""
return [ return [
'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)', 'metrics/precision(B)', 'metrics/recall(B)', 'metrics/mAP50(B)', 'metrics/mAP50-95(B)',
'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)'] 'metrics/precision(P)', 'metrics/recall(P)', 'metrics/mAP50(P)', 'metrics/mAP50-95(P)']
def mean_results(self): def mean_results(self):
"""Return the mean results of box and pose."""
return self.box.mean_results() + self.pose.mean_results() return self.box.mean_results() + self.pose.mean_results()
def class_result(self, i): def class_result(self, i):
"""Return the class-wise detection results for a specific class i."""
return self.box.class_result(i) + self.pose.class_result(i) return self.box.class_result(i) + self.pose.class_result(i)
@property @property
def maps(self): def maps(self):
"""Returns the mean average precision (mAP) per class for both box and pose detections."""
return self.box.maps + self.pose.maps return self.box.maps + self.pose.maps
@property @property
def fitness(self): def fitness(self):
"""Computes classification metrics and speed using the `targets` and `pred` inputs."""
return self.pose.fitness() + self.box.fitness() return self.pose.fitness() + self.box.fitness()
@ -935,12 +960,15 @@ class ClassifyMetrics(SimpleClass):
@property @property
def fitness(self): def fitness(self):
"""Returns top-5 accuracy as fitness score."""
return self.top5 return self.top5
@property @property
def results_dict(self): def results_dict(self):
"""Returns a dictionary with model's performance metrics and fitness score."""
return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness])) return dict(zip(self.keys + ['fitness'], [self.top1, self.top5, self.fitness]))
@property @property
def keys(self): def keys(self):
"""Returns a list of keys for the results_dict property."""
return ['metrics/accuracy_top1', 'metrics/accuracy_top5'] return ['metrics/accuracy_top1', 'metrics/accuracy_top5']

@ -33,6 +33,7 @@ class Colors:
dtype=np.uint8) dtype=np.uint8)
def __call__(self, i, bgr=False): def __call__(self, i, bgr=False):
"""Converts hex color codes to rgb values."""
c = self.palette[int(i) % self.n] c = self.palette[int(i) % self.n]
return (c[2], c[1], c[0]) if bgr else c return (c[2], c[1], c[0]) if bgr else c
@ -47,6 +48,7 @@ colors = Colors() # create instance for 'from utils.plots import colors'
class Annotator: class Annotator:
# YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations # YOLOv8 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'): def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
"""Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.' assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
self.pil = pil or non_ascii self.pil = pil or non_ascii
@ -71,7 +73,7 @@ class Annotator:
self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)): def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
# Add one xyxy box to image with label """Add one xyxy box to image with label."""
if isinstance(box, torch.Tensor): if isinstance(box, torch.Tensor):
box = box.tolist() box = box.tolist()
if self.pil or not is_ascii(label): if self.pil or not is_ascii(label):
@ -191,7 +193,7 @@ class Annotator:
self.draw.rectangle(xy, fill, outline, width) self.draw.rectangle(xy, fill, outline, width)
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'): def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
# Add text to image (PIL-only) """Adds text to an image using PIL or cv2."""
if anchor == 'bottom': # start y from font bottom if anchor == 'bottom': # start y from font bottom
w, h = self.font.getsize(text) # text width, height w, h = self.font.getsize(text) # text width, height
xy[1] += 1 - h xy[1] += 1 - h
@ -214,6 +216,7 @@ class Annotator:
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
@plt_settings() @plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path('')): def plot_labels(boxes, cls, names=(), save_dir=Path('')):
"""Save and plot image with no axis or spines."""
import pandas as pd import pandas as pd
import seaborn as sn import seaborn as sn
@ -260,7 +263,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path('')):
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True): def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop."""
b = xyxy2xywh(xyxy.view(-1, 4)) # boxes b = xyxy2xywh(xyxy.view(-1, 4)) # boxes
if square: if square:
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square

@ -69,6 +69,7 @@ class TaskAlignedAssigner(nn.Module):
""" """
def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9): def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
"""Initialize a TaskAlignedAssigner object with customizable hyperparameters."""
super().__init__() super().__init__()
self.topk = topk self.topk = topk
self.num_classes = num_classes self.num_classes = num_classes
@ -137,6 +138,7 @@ class TaskAlignedAssigner(nn.Module):
return mask_pos, align_metric, overlaps return mask_pos, align_metric, overlaps
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt): def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
"""Compute alignment metric given predicted and ground truth bounding boxes."""
na = pd_bboxes.shape[-2] na = pd_bboxes.shape[-2]
mask_gt = mask_gt.bool() # b, max_num_obj, h*w mask_gt = mask_gt.bool() # b, max_num_obj, h*w
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device) overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)

@ -43,6 +43,7 @@ def smart_inference_mode():
"""Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.""" """Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
def decorate(fn): def decorate(fn):
"""Applies appropriate torch decorator for inference mode based on torch version."""
return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
return decorate return decorate
@ -232,7 +233,7 @@ def make_divisible(x, divisor):
def copy_attr(a, b, include=(), exclude=()): def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from 'b' to 'a', options to only include [...] and to exclude [...] """Copies attributes from object 'b' to object 'a', with options to include/exclude certain attributes."""
for k, v in b.__dict__.items(): for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude: if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue continue
@ -246,7 +247,7 @@ def get_latest_opset():
def intersect_dicts(da, db, exclude=()): def intersect_dicts(da, db, exclude=()):
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values """Returns a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values."""
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
@ -310,7 +311,7 @@ class ModelEMA:
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}' # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')): def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes """Updates attributes and saves stripped model with optimizer removed."""
if self.enabled: if self.enabled:
copy_attr(self.ema, model, include, exclude) copy_attr(self.ema, model, include, exclude)

@ -10,10 +10,12 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
class ClassificationPredictor(BasePredictor): class ClassificationPredictor(BasePredictor):
def preprocess(self, img): def preprocess(self, img):
"""Converts input image to model-compatible data type."""
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions to return Results objects."""
results = [] results = []
for i, pred in enumerate(preds): for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
@ -25,6 +27,7 @@ class ClassificationPredictor(BasePredictor):
def predict(cfg=DEFAULT_CFG, use_python=False): def predict(cfg=DEFAULT_CFG, use_python=False):
"""Run YOLO model predictions on input images/videos."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg' else 'https://ultralytics.com/images/bus.jpg'

@ -14,15 +14,18 @@ from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
class ClassificationTrainer(BaseTrainer): class ClassificationTrainer(BaseTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides['task'] = 'classify' overrides['task'] = 'classify'
super().__init__(cfg, overrides, _callbacks) super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self): def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data['names'] self.model.names = self.data['names']
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights: if weights:
model.load(weights) model.load(weights)
@ -69,6 +72,7 @@ class ClassificationTrainer(BaseTrainer):
return # dont return ckpt. Classification doesn't support resume return # dont return ckpt. Classification doesn't support resume
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'): def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
loader = build_classification_dataloader(path=dataset_path, loader = build_classification_dataloader(path=dataset_path,
imgsz=self.args.imgsz, imgsz=self.args.imgsz,
batch_size=batch_size if mode == 'train' else (batch_size * 2), batch_size=batch_size if mode == 'train' else (batch_size * 2),
@ -84,19 +88,23 @@ class ClassificationTrainer(BaseTrainer):
return loader return loader
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
batch['img'] = batch['img'].to(self.device) batch['img'] = batch['img'].to(self.device)
batch['cls'] = batch['cls'].to(self.device) batch['cls'] = batch['cls'].to(self.device)
return batch return batch
def progress_string(self): def progress_string(self):
"""Returns a formatted string showing training progress."""
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \ return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self): def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss'] self.loss_names = ['loss']
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir) return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
def criterion(self, preds, batch): def criterion(self, preds, batch):
"""Compute the classification loss between predictions and true labels."""
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
loss_items = loss.detach() loss_items = loss.detach()
return loss, loss_items return loss, loss_items
@ -113,9 +121,11 @@ class ClassificationTrainer(BaseTrainer):
return dict(zip(keys, loss_items)) return dict(zip(keys, loss_items))
def resume_training(self, ckpt): def resume_training(self, ckpt):
"""Resumes training from a given checkpoint."""
pass pass
def final_eval(self): def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best: for f in self.last, self.best:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
@ -130,6 +140,7 @@ class ClassificationTrainer(BaseTrainer):
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist") data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else '' device = cfg.device if cfg.device is not None else ''

@ -9,14 +9,17 @@ from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
class ClassificationValidator(BaseValidator): class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'classify' self.args.task = 'classify'
self.metrics = ClassifyMetrics() self.metrics = ClassifyMetrics()
def get_desc(self): def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc') return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model): def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names self.names = model.names
self.nc = len(model.names) self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify') self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
@ -24,17 +27,20 @@ class ClassificationValidator(BaseValidator):
self.targets = [] self.targets = []
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
batch['img'] = batch['img'].to(self.device, non_blocking=True) batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float() batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device) batch['cls'] = batch['cls'].to(self.device)
return batch return batch
def update_metrics(self, preds, batch): def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.model.names), 5) n5 = min(len(self.model.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5]) self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls']) self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets) self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots: if self.args.plots:
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values())) self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
@ -42,10 +48,12 @@ class ClassificationValidator(BaseValidator):
self.metrics.confusion_matrix = self.confusion_matrix self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self): def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred) self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict return self.metrics.results_dict
def get_dataloader(self, dataset_path, batch_size): def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""
return build_classification_dataloader(path=dataset_path, return build_classification_dataloader(path=dataset_path,
imgsz=self.args.imgsz, imgsz=self.args.imgsz,
batch_size=batch_size, batch_size=batch_size,
@ -54,11 +62,13 @@ class ClassificationValidator(BaseValidator):
workers=self.args.workers) workers=self.args.workers)
def print_results(self): def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5)) LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def val(cfg=DEFAULT_CFG, use_python=False): def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate YOLO model using custom data."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' data = cfg.data or 'mnist160'

@ -10,12 +10,14 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT, ops
class DetectionPredictor(BasePredictor): class DetectionPredictor(BasePredictor):
def preprocess(self, img): def preprocess(self, img):
"""Convert an image to PyTorch tensor and normalize pixel values."""
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
img /= 255 # 0 - 255 to 0.0 - 1.0 img /= 255 # 0 - 255 to 0.0 - 1.0
return img return img
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds, preds = ops.non_max_suppression(preds,
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,
@ -35,6 +37,7 @@ class DetectionPredictor(BasePredictor):
def predict(cfg=DEFAULT_CFG, use_python=False): def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO model inference on input image(s)."""
model = cfg.model or 'yolov8n.pt' model = cfg.model or 'yolov8n.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg' else 'https://ultralytics.com/images/bus.jpg'

@ -44,6 +44,7 @@ class DetectionTrainer(BaseTrainer):
rect=mode == 'val', data_info=self.data)[0] rect=mode == 'val', data_info=self.data)[0]
def preprocess_batch(self, batch): def preprocess_batch(self, batch):
"""Preprocesses a batch of images by scaling and converting to float."""
batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255 batch['img'] = batch['img'].to(self.device, non_blocking=True).float() / 255
return batch return batch
@ -58,16 +59,19 @@ class DetectionTrainer(BaseTrainer):
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
"""Return a YOLO detection model."""
model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1) model = DetectionModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights: if weights:
model.load(weights) model.load(weights)
return model return model
def get_validator(self): def get_validator(self):
"""Returns a DetectionValidator for YOLO model validation."""
self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'cls_loss', 'dfl_loss'
return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.detect.DetectionValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch): def criterion(self, preds, batch):
"""Compute loss for YOLO prediction and ground-truth."""
if not hasattr(self, 'compute_loss'): if not hasattr(self, 'compute_loss'):
self.compute_loss = Loss(de_parallel(self.model)) self.compute_loss = Loss(de_parallel(self.model))
return self.compute_loss(preds, batch) return self.compute_loss(preds, batch)
@ -85,10 +89,12 @@ class DetectionTrainer(BaseTrainer):
return keys return keys
def progress_string(self): def progress_string(self):
"""Returns a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ('\n' + '%11s' * return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size') (4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'], plot_images(images=batch['img'],
batch_idx=batch['batch_idx'], batch_idx=batch['batch_idx'],
cls=batch['cls'].squeeze(-1), cls=batch['cls'].squeeze(-1),
@ -97,9 +103,11 @@ class DetectionTrainer(BaseTrainer):
fname=self.save_dir / f'train_batch{ni}.jpg') fname=self.save_dir / f'train_batch{ni}.jpg')
def plot_metrics(self): def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv) # save results.png plot_results(file=self.csv) # save results.png
def plot_training_labels(self): def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0) boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0) cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir) plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir)
@ -129,6 +137,7 @@ class Loss:
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device) self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets, batch_size, scale_tensor): def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
if targets.shape[0] == 0: if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 5, device=self.device) out = torch.zeros(batch_size, 0, 5, device=self.device)
else: else:
@ -145,6 +154,7 @@ class Loss:
return out return out
def bbox_decode(self, anchor_points, pred_dist): def bbox_decode(self, anchor_points, pred_dist):
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl: if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
@ -153,6 +163,7 @@ class Loss:
return dist2bbox(pred_dist, anchor_points, xywh=False) return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds, batch): def __call__(self, preds, batch):
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
@ -199,6 +210,7 @@ class Loss:
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8n.pt' model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist") data = cfg.data or 'coco128.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else '' device = cfg.device if cfg.device is not None else ''

@ -19,6 +19,7 @@ from ultralytics.yolo.utils.torch_utils import de_parallel
class DetectionValidator(BaseValidator): class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'detect' self.args.task = 'detect'
self.is_coco = False self.is_coco = False
@ -28,6 +29,7 @@ class DetectionValidator(BaseValidator):
self.niou = self.iouv.numel() self.niou = self.iouv.numel()
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses batch of images for YOLO training."""
batch['img'] = batch['img'].to(self.device, non_blocking=True) batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255 batch['img'] = (batch['img'].half() if self.args.half else batch['img'].float()) / 255
for k in ['batch_idx', 'cls', 'bboxes']: for k in ['batch_idx', 'cls', 'bboxes']:
@ -40,6 +42,7 @@ class DetectionValidator(BaseValidator):
return batch return batch
def init_metrics(self, model): def init_metrics(self, model):
"""Initialize evaluation metrics for YOLO."""
val = self.data.get(self.args.split, '') # validation path val = self.data.get(self.args.split, '') # validation path
self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO self.is_coco = isinstance(val, str) and 'coco' in val and val.endswith(f'{os.sep}val2017.txt') # is COCO
self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000)) self.class_map = ops.coco80_to_coco91_class() if self.is_coco else list(range(1000))
@ -54,9 +57,11 @@ class DetectionValidator(BaseValidator):
self.stats = [] self.stats = []
def get_desc(self): def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)') return ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds): def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
preds = ops.non_max_suppression(preds, preds = ops.non_max_suppression(preds,
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,
@ -113,10 +118,12 @@ class DetectionValidator(BaseValidator):
self.save_one_txt(predn, self.args.save_conf, shape, file) self.save_one_txt(predn, self.args.save_conf, shape, file)
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
"""Set final values for metrics speed and confusion matrix."""
self.metrics.speed = self.speed self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self): def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*self.stats)] # to numpy
if len(stats) and stats[0].any(): if len(stats) and stats[0].any():
self.metrics.process(*stats) self.metrics.process(*stats)
@ -124,6 +131,7 @@ class DetectionValidator(BaseValidator):
return self.metrics.results_dict return self.metrics.results_dict
def print_results(self): def print_results(self):
"""Prints training/validation set metrics per class."""
pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format pf = '%22s' + '%11i' * 2 + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results())) LOGGER.info(pf % ('all', self.seen, self.nt_per_class.sum(), *self.metrics.mean_results()))
if self.nt_per_class.sum() == 0: if self.nt_per_class.sum() == 0:
@ -183,6 +191,7 @@ class DetectionValidator(BaseValidator):
mode='val')[0] mode='val')[0]
def plot_val_samples(self, batch, ni): def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(batch['img'], plot_images(batch['img'],
batch['batch_idx'], batch['batch_idx'],
batch['cls'].squeeze(-1), batch['cls'].squeeze(-1),
@ -192,6 +201,7 @@ class DetectionValidator(BaseValidator):
names=self.names) names=self.names)
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'], plot_images(batch['img'],
*output_to_target(preds, max_det=15), *output_to_target(preds, max_det=15),
paths=batch['im_file'], paths=batch['im_file'],
@ -199,6 +209,7 @@ class DetectionValidator(BaseValidator):
names=self.names) # pred names=self.names) # pred
def save_one_txt(self, predn, save_conf, shape, file): def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
for *xyxy, conf, cls in predn.tolist(): for *xyxy, conf, cls in predn.tolist():
xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh xywh = (ops.xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
@ -207,6 +218,7 @@ class DetectionValidator(BaseValidator):
f.write(('%g ' * len(line)).rstrip() % line + '\n') f.write(('%g ' * len(line)).rstrip() % line + '\n')
def pred_to_json(self, predn, filename): def pred_to_json(self, predn, filename):
"""Serialize YOLO predictions to COCO json format."""
stem = Path(filename).stem stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh box = ops.xyxy2xywh(predn[:, :4]) # xywh
@ -219,6 +231,7 @@ class DetectionValidator(BaseValidator):
'score': round(p[4], 5)}) 'score': round(p[4], 5)})
def eval_json(self, stats): def eval_json(self, stats):
"""Evaluates YOLO output in JSON format and returns performance statistics."""
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions pred_json = self.save_dir / 'predictions.json' # predictions
@ -245,6 +258,7 @@ class DetectionValidator(BaseValidator):
def val(cfg=DEFAULT_CFG, use_python=False): def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation dataset."""
model = cfg.model or 'yolov8n.pt' model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco128.yaml' data = cfg.data or 'coco128.yaml'

@ -8,6 +8,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
class PosePredictor(DetectionPredictor): class PosePredictor(DetectionPredictor):
def postprocess(self, preds, img, orig_img): def postprocess(self, preds, img, orig_img):
"""Return detection results for a given input image or list of images."""
preds = ops.non_max_suppression(preds, preds = ops.non_max_suppression(preds,
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,
@ -35,6 +36,7 @@ class PosePredictor(DetectionPredictor):
def predict(cfg=DEFAULT_CFG, use_python=False): def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO to predict objects in an image or video."""
model = cfg.model or 'yolov8n-pose.pt' model = cfg.model or 'yolov8n-pose.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg' else 'https://ultralytics.com/images/bus.jpg'

@ -21,12 +21,14 @@ from ultralytics.yolo.v8.detect.train import Loss
class PoseTrainer(v8.detect.DetectionTrainer): class PoseTrainer(v8.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a PoseTrainer object with specified configurations and overrides."""
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides['task'] = 'pose' overrides['task'] = 'pose'
super().__init__(cfg, overrides, _callbacks) super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
"""Get pose estimation model with specified configuration and weights."""
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose) model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
if weights: if weights:
model.load(weights) model.load(weights)
@ -34,19 +36,23 @@ class PoseTrainer(v8.detect.DetectionTrainer):
return model return model
def set_model_attributes(self): def set_model_attributes(self):
"""Sets keypoints shape attribute of PoseModel."""
super().set_model_attributes() super().set_model_attributes()
self.model.kpt_shape = self.data['kpt_shape'] self.model.kpt_shape = self.data['kpt_shape']
def get_validator(self): def get_validator(self):
"""Returns an instance of the PoseValidator class for validation."""
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch): def criterion(self, preds, batch):
"""Computes pose loss for the YOLO model."""
if not hasattr(self, 'compute_loss'): if not hasattr(self, 'compute_loss'):
self.compute_loss = PoseLoss(de_parallel(self.model)) self.compute_loss = PoseLoss(de_parallel(self.model))
return self.compute_loss(preds, batch) return self.compute_loss(preds, batch)
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
images = batch['img'] images = batch['img']
kpts = batch['keypoints'] kpts = batch['keypoints']
cls = batch['cls'].squeeze(-1) cls = batch['cls'].squeeze(-1)
@ -62,6 +68,7 @@ class PoseTrainer(v8.detect.DetectionTrainer):
fname=self.save_dir / f'train_batch{ni}.jpg') fname=self.save_dir / f'train_batch{ni}.jpg')
def plot_metrics(self): def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True) # save results.png plot_results(file=self.csv, pose=True) # save results.png
@ -78,6 +85,7 @@ class PoseLoss(Loss):
self.keypoint_loss = KeypointLoss(sigmas=sigmas) self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds, batch): def __call__(self, preds, batch):
"""Calculate the total loss and detach it."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1] feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split( pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
@ -145,6 +153,7 @@ class PoseLoss(Loss):
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl) return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
def kpts_decode(self, anchor_points, pred_kpts): def kpts_decode(self, anchor_points, pred_kpts):
"""Decodes predicted keypoints to image coordinates."""
y = pred_kpts.clone() y = pred_kpts.clone()
y[..., :2] *= 2.0 y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5 y[..., 0] += anchor_points[:, [0]] - 0.5
@ -153,6 +162,7 @@ class PoseLoss(Loss):
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO model on the given data and device."""
model = cfg.model or 'yolov8n-pose.yaml' model = cfg.model or 'yolov8n-pose.yaml'
data = cfg.data or 'coco8-pose.yaml' data = cfg.data or 'coco8-pose.yaml'
device = cfg.device if cfg.device is not None else '' device = cfg.device if cfg.device is not None else ''

@ -15,20 +15,24 @@ from ultralytics.yolo.v8.detect import DetectionValidator
class PoseValidator(DetectionValidator): class PoseValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose' self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir) self.metrics = PoseMetrics(save_dir=self.save_dir)
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
batch = super().preprocess(batch) batch = super().preprocess(batch)
batch['keypoints'] = batch['keypoints'].to(self.device).float() batch['keypoints'] = batch['keypoints'].to(self.device).float()
return batch return batch
def get_desc(self): def get_desc(self):
"""Returns description of evaluation metrics in string format."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P', return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
'R', 'mAP50', 'mAP50-95)') 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds): def postprocess(self, preds):
"""Apply non-maximum suppression and return detections with high confidence scores."""
preds = ops.non_max_suppression(preds, preds = ops.non_max_suppression(preds,
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,
@ -40,6 +44,7 @@ class PoseValidator(DetectionValidator):
return preds return preds
def init_metrics(self, model): def init_metrics(self, model):
"""Initiate pose estimation metrics for YOLO model."""
super().init_metrics(model) super().init_metrics(model)
self.kpt_shape = self.data['kpt_shape'] self.kpt_shape = self.data['kpt_shape']
is_pose = self.kpt_shape == [17, 3] is_pose = self.kpt_shape == [17, 3]
@ -137,6 +142,7 @@ class PoseValidator(DetectionValidator):
return torch.tensor(correct, dtype=torch.bool, device=detections.device) return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni): def plot_val_samples(self, batch, ni):
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
plot_images(batch['img'], plot_images(batch['img'],
batch['batch_idx'], batch['batch_idx'],
batch['cls'].squeeze(-1), batch['cls'].squeeze(-1),
@ -147,6 +153,7 @@ class PoseValidator(DetectionValidator):
names=self.names) names=self.names)
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0) pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0)
plot_images(batch['img'], plot_images(batch['img'],
*output_to_target(preds, max_det=15), *output_to_target(preds, max_det=15),
@ -156,6 +163,7 @@ class PoseValidator(DetectionValidator):
names=self.names) # pred names=self.names) # pred
def pred_to_json(self, predn, filename): def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""
stem = Path(filename).stem stem = Path(filename).stem
image_id = int(stem) if stem.isnumeric() else stem image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn[:, :4]) # xywh box = ops.xyxy2xywh(predn[:, :4]) # xywh
@ -169,6 +177,7 @@ class PoseValidator(DetectionValidator):
'score': round(p[4], 5)}) 'score': round(p[4], 5)})
def eval_json(self, stats): def eval_json(self, stats):
"""Evaluates object detection model using COCO JSON format."""
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions pred_json = self.save_dir / 'predictions.json' # predictions
@ -197,6 +206,7 @@ class PoseValidator(DetectionValidator):
def val(cfg=DEFAULT_CFG, use_python=False): def val(cfg=DEFAULT_CFG, use_python=False):
"""Performs validation on YOLO model using given data."""
model = cfg.model or 'yolov8n-pose.pt' model = cfg.model or 'yolov8n-pose.pt'
data = cfg.data or 'coco8-pose.yaml' data = cfg.data or 'coco8-pose.yaml'

@ -41,6 +41,7 @@ class SegmentationPredictor(DetectionPredictor):
def predict(cfg=DEFAULT_CFG, use_python=False): def predict(cfg=DEFAULT_CFG, use_python=False):
"""Runs YOLO object detection on an image or video source."""
model = cfg.model or 'yolov8n-seg.pt' model = cfg.model or 'yolov8n-seg.pt'
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \ source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg' else 'https://ultralytics.com/images/bus.jpg'

@ -18,12 +18,14 @@ from ultralytics.yolo.v8.detect.train import Loss
class SegmentationTrainer(v8.detect.DetectionTrainer): class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a SegmentationTrainer object with given arguments."""
if overrides is None: if overrides is None:
overrides = {} overrides = {}
overrides['task'] = 'segment' overrides['task'] = 'segment'
super().__init__(cfg, overrides, _callbacks) super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose=True): def get_model(self, cfg=None, weights=None, verbose=True):
"""Return SegmentationModel initialized with specified config and weights."""
model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1) model = SegmentationModel(cfg, ch=3, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights: if weights:
model.load(weights) model.load(weights)
@ -31,15 +33,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
return model return model
def get_validator(self): def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss' self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args)) return v8.segment.SegmentationValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def criterion(self, preds, batch): def criterion(self, preds, batch):
"""Returns the computed loss using the SegLoss class on the given predictions and batch."""
if not hasattr(self, 'compute_loss'): if not hasattr(self, 'compute_loss'):
self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask) self.compute_loss = SegLoss(de_parallel(self.model), overlap=self.args.overlap_mask)
return self.compute_loss(preds, batch) return self.compute_loss(preds, batch)
def plot_training_samples(self, batch, ni): def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
images = batch['img'] images = batch['img']
masks = batch['masks'] masks = batch['masks']
cls = batch['cls'].squeeze(-1) cls = batch['cls'].squeeze(-1)
@ -49,6 +54,7 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg') plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
def plot_metrics(self): def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True) # save results.png plot_results(file=self.csv, segment=True) # save results.png
@ -61,6 +67,7 @@ class SegLoss(Loss):
self.overlap = overlap self.overlap = overlap
def __call__(self, preds, batch): def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl loss = torch.zeros(4, device=self.device) # box, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
@ -147,6 +154,7 @@ class SegLoss(Loss):
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train a YOLO segmentation model based on passed arguments."""
model = cfg.model or 'yolov8n-seg.pt' model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist") data = cfg.data or 'coco128-seg.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else '' device = cfg.device if cfg.device is not None else ''

@ -17,16 +17,19 @@ from ultralytics.yolo.v8.detect import DetectionValidator
class SegmentationValidator(DetectionValidator): class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment' self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir) self.metrics = SegmentMetrics(save_dir=self.save_dir)
def preprocess(self, batch): def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
batch = super().preprocess(batch) batch = super().preprocess(batch)
batch['masks'] = batch['masks'].to(self.device).float() batch['masks'] = batch['masks'].to(self.device).float()
return batch return batch
def init_metrics(self, model): def init_metrics(self, model):
"""Initialize metrics and select mask processing function based on save_json flag."""
super().init_metrics(model) super().init_metrics(model)
self.plot_masks = [] self.plot_masks = []
if self.args.save_json: if self.args.save_json:
@ -36,10 +39,12 @@ class SegmentationValidator(DetectionValidator):
self.process = ops.process_mask # faster self.process = ops.process_mask # faster
def get_desc(self): def get_desc(self):
"""Return a formatted description of evaluation metrics."""
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P', return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Mask(P',
'R', 'mAP50', 'mAP50-95)') 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds): def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0], p = ops.non_max_suppression(preds[0],
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,
@ -119,6 +124,7 @@ class SegmentationValidator(DetectionValidator):
# save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') # save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
def finalize_metrics(self, *args, **kwargs): def finalize_metrics(self, *args, **kwargs):
"""Sets speed and confusion matrix for evaluation metrics."""
self.metrics.speed = self.speed self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix self.metrics.confusion_matrix = self.confusion_matrix
@ -160,6 +166,7 @@ class SegmentationValidator(DetectionValidator):
return torch.tensor(correct, dtype=torch.bool, device=detections.device) return torch.tensor(correct, dtype=torch.bool, device=detections.device)
def plot_val_samples(self, batch, ni): def plot_val_samples(self, batch, ni):
"""Plots validation samples with bounding box labels."""
plot_images(batch['img'], plot_images(batch['img'],
batch['batch_idx'], batch['batch_idx'],
batch['cls'].squeeze(-1), batch['cls'].squeeze(-1),
@ -170,6 +177,7 @@ class SegmentationValidator(DetectionValidator):
names=self.names) names=self.names)
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(batch['img'], plot_images(batch['img'],
*output_to_target(preds[0], max_det=15), *output_to_target(preds[0], max_det=15),
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
@ -184,6 +192,7 @@ class SegmentationValidator(DetectionValidator):
from pycocotools.mask import encode # noqa from pycocotools.mask import encode # noqa
def single_encode(x): def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0] rle = encode(np.asarray(x[:, :, None], order='F', dtype='uint8'))[0]
rle['counts'] = rle['counts'].decode('utf-8') rle['counts'] = rle['counts'].decode('utf-8')
return rle return rle
@ -204,6 +213,7 @@ class SegmentationValidator(DetectionValidator):
'segmentation': rles[i]}) 'segmentation': rles[i]})
def eval_json(self, stats): def eval_json(self, stats):
"""Return COCO-style object detection evaluation metrics."""
if self.args.save_json and self.is_coco and len(self.jdict): if self.args.save_json and self.is_coco and len(self.jdict):
anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations anno_json = self.data['path'] / 'annotations/instances_val2017.json' # annotations
pred_json = self.save_dir / 'predictions.json' # predictions pred_json = self.save_dir / 'predictions.json' # predictions
@ -232,6 +242,7 @@ class SegmentationValidator(DetectionValidator):
def val(cfg=DEFAULT_CFG, use_python=False): def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate trained YOLO model on validation data."""
model = cfg.model or 'yolov8n-seg.pt' model = cfg.model or 'yolov8n-seg.pt'
data = cfg.data or 'coco128-seg.yaml' data = cfg.data or 'coco128-seg.yaml'

Loading…
Cancel
Save