|
|
|
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
|
|
|
|
"""
|
|
|
|
Common modules
|
|
|
|
"""
|
|
|
|
|
|
|
|
import math
|
|
|
|
import warnings
|
|
|
|
from copy import copy
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import pandas as pd
|
|
|
|
import requests
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from PIL import Image, ImageOps
|
|
|
|
from torch.cuda import amp
|
|
|
|
|
|
|
|
from ultralytics.yolo.data.augment import LetterBox
|
|
|
|
from ultralytics.yolo.utils import LOGGER, colorstr
|
|
|
|
from ultralytics.yolo.utils.checks import check_version
|
|
|
|
from ultralytics.yolo.utils.files import increment_path
|
|
|
|
from ultralytics.yolo.utils.ops import Profile, make_divisible, non_max_suppression, scale_boxes, xyxy2xywh
|
|
|
|
from ultralytics.yolo.utils.plotting import Annotator, colors, save_one_box
|
|
|
|
from ultralytics.yolo.utils.torch_utils import copy_attr, smart_inference_mode
|
|
|
|
|
|
|
|
from .autobackend import AutoBackend
|
|
|
|
|
|
|
|
# from utils.plots import feature_visualization TODO
|
|
|
|
|
|
|
|
|
|
|
|
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Return the padding that gives 'same'-shaped outputs for a stride-1 convolution.

    Args:
        k (int | sequence[int]): kernel size, a single int or one value per spatial dimension.
        p (int | sequence[int] | None): explicit padding; returned unchanged when not None.
        d (int): dilation; when d > 1 the effective kernel size d * (k - 1) + 1 is used.

    Returns:
        The explicit ``p`` if given, otherwise ``k // 2`` (int kernel) or a list of per-dimension
        halves (sequence kernel).
    """
    if d > 1:
        # Dilation spreads the kernel taps apart, enlarging its effective footprint.
        if isinstance(k, int):
            k = d * (k - 1) + 1
        else:
            k = [d * (size - 1) + 1 for size in k]
    if p is not None:
        return p
    return k // 2 if isinstance(k, int) else [size // 2 for size in k]
|
|
|
|
|
|
|
|
|
|
|
|
class Conv(nn.Module):
    """Standard convolution: Conv2d -> BatchNorm2d -> activation.

    Args (in order): ch_in, ch_out, kernel, stride, padding, groups, dilation, activation.
    ``act=True`` selects the class-level ``default_act`` (one SiLU instance shared by all Convs);
    an ``nn.Module`` is used as given; any other value disables the activation (``nn.Identity``).
    """
    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = autopad(k, p, d)  # 'same' padding unless p is given explicitly
        # bias=False: the BatchNorm2d that follows provides its own learnable shift
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)

    def forward_fuse(self, x):
        """Apply convolution and activation only, skipping ``self.bn``.

        Presumably used once BatchNorm has been folded into the conv weights elsewhere.
        """
        return self.act(self.conv(x))
|
|
|
|
|
|
|
|
|
|
|
|
class DWConv(Conv):
    """Depth-wise convolution: a ``Conv`` whose group count is gcd(ch_in, ch_out)."""

    def __init__(self, c1, c2, k=1, s=1, d=1, act=True):  # ch_in, ch_out, kernel, stride, dilation, activation
        groups = math.gcd(c1, c2)  # largest group count dividing both channel counts
        super().__init__(c1, c2, k, s, g=groups, d=d, act=act)
|
|
|
|
|
|
|
|
|
|
|
|
class DWConvTranspose2d(nn.ConvTranspose2d):
    """Depth-wise transposed convolution: ``nn.ConvTranspose2d`` with groups = gcd(ch_in, ch_out).

    Args: ch_in, ch_out, kernel, stride, padding (p1), output padding (p2).
    """

    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):
        n_groups = math.gcd(c1, c2)
        super().__init__(in_channels=c1,
                         out_channels=c2,
                         kernel_size=k,
                         stride=s,
                         padding=p1,
                         output_padding=p2,
                         groups=n_groups)
|
|
|
|
|
|
|
|
|
|
|
|
class ConvTranspose(nn.Module):
    """Transposed 2D convolution -> optional BatchNorm2d -> activation.

    Args: ch_in, ch_out, kernel, stride, padding, bn (when True a BatchNorm2d follows and the
    transposed conv has no bias), act (True -> shared class-level SiLU, an ``nn.Module`` -> used
    as given, anything else -> ``nn.Identity``).
    """
    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
        super().__init__()
        self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
        self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
        if act is True:
            activation = self.default_act
        elif isinstance(act, nn.Module):
            activation = act
        else:
            activation = nn.Identity()
        self.act = activation

    def forward(self, x):
        """Upsample with the transposed conv, then normalize and activate."""
        y = self.conv_transpose(x)
        y = self.bn(y)
        return self.act(y)
|
|
|
|
|
|
|
|
|
|
|
|
class DFL(nn.Module):
    """DFL integral module (presumably Distribution Focal Loss).

    For each of 4 box sides, turns ``c1`` bin logits into the expected bin index. It applies a
    softmax over the bins, then a fixed, non-trainable 1x1 conv whose weights are 0..c1-1.
    """

    def __init__(self, c1=16):
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        # Copy the bin indices into the existing (1, c1, 1, 1) weight; no new Parameter is needed.
        self.conv.weight.data[:] = bins.view(1, c1, 1, 1)
        self.c1 = c1

    def forward(self, x):
        """Return the expected bin value for each box side.

        Args:
            x (torch.Tensor): shape (batch, 4 * c1, anchors).

        Returns:
            torch.Tensor: shape (batch, 4, anchors).
        """
        b, _, a = x.shape  # batch, channels, anchors
        # (b, 4*c1, a) -> (b, 4, c1, a) -> (b, c1, 4, a): put the bins on the conv's channel axis
        probs = x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)
        return self.conv(probs).view(b, 4, a)
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerLayer(nn.Module):
    """Transformer layer (https://arxiv.org/abs/2010.11929) with the LayerNorm layers removed.

    Tensor layout follows ``nn.MultiheadAttention``'s default (batch_first=False):
    (sequence, batch, c). Both sub-layers are residual.
    """

    def __init__(self, c, num_heads):
        super().__init__()
        # Bias-free projections that feed query/key/value into the attention block
        self.q = nn.Linear(c, c, bias=False)
        self.k = nn.Linear(c, c, bias=False)
        self.v = nn.Linear(c, c, bias=False)
        self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
        # Feed-forward part: two bias-free linear layers with no activation between them
        self.fc1 = nn.Linear(c, c, bias=False)
        self.fc2 = nn.Linear(c, c, bias=False)

    def forward(self, x):
        attn_out, _ = self.ma(self.q(x), self.k(x), self.v(x))
        hidden = attn_out + x
        return self.fc2(self.fc1(hidden)) + hidden
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Vision Transformer block (https://arxiv.org/abs/2010.11929) applied to a 2D feature map.

    An optional ``Conv`` first maps c1 -> c2 channels. The map is then flattened into a token
    sequence, a learnable linear position embedding is added, and ``num_layers``
    TransformerLayers are applied.
    """

    def __init__(self, c1, c2, num_heads, num_layers):
        super().__init__()
        self.conv = Conv(c1, c2) if c1 != c2 else None  # channel adapter only when needed
        self.linear = nn.Linear(c2, c2)  # learnable position embedding
        layers = [TransformerLayer(c2, num_heads) for _ in range(num_layers)]
        self.tr = nn.Sequential(*layers)
        self.c2 = c2

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        b, _, h, w = x.shape
        tokens = x.flatten(2).permute(2, 0, 1)  # (b, c2, h*w) -> (h*w, b, c2)
        tokens = self.tr(tokens + self.linear(tokens))
        return tokens.permute(1, 2, 0).reshape(b, self.c2, h, w)
|
|
|
|
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
    """Standard bottleneck: two Convs with an optional residual connection.

    Args (in order): ch_in, ch_out, shortcut, groups (applied to the second conv), kernels
    (pair for cv1 / cv2), expansion (hidden channels = int(ch_out * e)). The residual is only
    added when ``shortcut`` is set and ch_in == ch_out.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        first_k, second_k = k[0], k[1]
        self.cv1 = Conv(c1, hidden, first_k, 1)
        self.cv2 = Conv(hidden, c2, second_k, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
|
|
|
|
|
|
|
|
|
|
|
|
class BottleneckCSP(nn.Module):
    """CSP Bottleneck (https://github.com/WongKinYiu/CrossStagePartialNetworks).

    Args: ch_in, ch_out, number of Bottlenecks, shortcut, groups, expansion.
    One branch runs cv1 -> Bottleneck stack -> cv3 and the other is a plain 1x1 conv (cv2).
    Their concatenation is batch-normalized, activated and fused by cv4.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = nn.Conv2d(c1, hidden, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(hidden, hidden, 1, 1, bias=False)
        self.cv4 = Conv(2 * hidden, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * hidden)  # normalizes the concatenated branches
        self.act = nn.SiLU()
        self.m = nn.Sequential(*[Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        main_branch = self.cv3(self.m(self.cv1(x)))
        side_branch = self.cv2(x)
        merged = torch.cat((main_branch, side_branch), 1)
        return self.cv4(self.act(self.bn(merged)))
|
|
|
|
|
|
|
|
|
|
|
|
class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions.

    Args: ch_in, ch_out, number of Bottlenecks, shortcut, groups, expansion.
    cv1 feeds the Bottleneck stack and cv2 is a parallel 1x1 branch; cv3 fuses their concatenation.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)  # optional act=FReLU(c2)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        main = self.m(self.cv1(x))
        side = self.cv2(x)
        return self.cv3(torch.cat((main, side), 1))
|
|
|
|
|
|
|
|
|
|
|
|
class C2(nn.Module):
    """CSP Bottleneck with 2 convolutions.

    Args: ch_in, ch_out, number of Bottlenecks, shortcut, groups, expansion.
    cv1 produces 2 * hidden channels, which are split in half. One half runs through the
    Bottleneck stack, the other is passed through, and cv2 fuses the concatenation.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c2, 1)  # optional act=FReLU(c2)
        kernels = ((3, 3), (3, 3))
        self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=kernels, e=1.0) for _ in range(n)))

    def forward(self, x):
        processed, passthrough = self.cv1(x).split((self.c, self.c), 1)
        return self.cv2(torch.cat((self.m(processed), passthrough), 1))
|
|
|
|
|
|
|
|
|
|
|
|
class C2f(nn.Module):
    """CSP Bottleneck with 2 convolutions whose fusion conv sees every Bottleneck output.

    Args: ch_in, ch_out, number of Bottlenecks, shortcut, groups, expansion.
    """

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        # cv2 fuses both cv1 halves plus one output per Bottleneck: (2 + n) * c channels
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        kernels = ((3, 3), (3, 3))
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=kernels, e=1.0) for _ in range(n))

    def forward(self, x):
        outputs = list(self.cv1(x).split((self.c, self.c), 1))
        for block in self.m:
            # Chained: each Bottleneck consumes the most recently produced output
            outputs.append(block(outputs[-1]))
        return self.cv2(torch.cat(outputs, 1))
|
|
|
|
|
|
|
|
|
|
|
|
class ChannelAttention(nn.Module):
    """Channel attention (https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet).

    Global average pooling -> 1x1 conv (with bias) -> sigmoid yields one weight per channel,
    which rescales the input.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = self.act(self.fc(self.pool(x)))  # one weight per channel, broadcast spatially
        return x * weights
|
|
|
|
|
|
|
|
|
|
|
|
class SpatialAttention(nn.Module):
    """Spatial attention: weights each position using a conv over the channel-wise mean and max maps.

    Args:
        kernel_size (int): 3 or 7 (checked with an assert); padding keeps the spatial size.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 1 if kernel_size != 7 else 3
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        mean_map = torch.mean(x, 1, keepdim=True)
        max_map = torch.max(x, 1, keepdim=True)[0]
        attention = self.act(self.cv1(torch.cat([mean_map, max_map], 1)))
        return x * attention
|
|
|
|
|
|
|
|
|
|
|
|
class CBAM(nn.Module):
    """CBAM-style attention: channel attention followed by spatial attention.

    Args:
        c1 (int): input channels; the output has the same shape as the input.
        ratio (int): accepted but not used by this implementation.
        kernel_size (int): kernel size of the spatial-attention conv (3 or 7).
    """

    def __init__(self, c1, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        channel_refined = self.channel_attention(x)
        return self.spatial_attention(channel_refined)
|
|
|
|
|
|
|
|
|
|
|
|
class C1(nn.Module):
    """A 1x1 Conv followed by a stack of ``n`` 3x3 Convs, with a residual around the stack.

    Args: ch_in, ch_out, number of 3x3 Convs.
    """

    def __init__(self, c1, c2, n=1):
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.m = nn.Sequential(*[Conv(c2, c2, 3) for _ in range(n)])

    def forward(self, x):
        residual = self.cv1(x)
        return self.m(residual) + residual
|
|
|
|
|
|
|
|
|
|
|
|
class C3x(C3):
    """C3 module with cross-convolutions: its Bottlenecks use (1x3, 3x1) kernel pairs."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        self.c_ = int(c2 * e)  # hidden channels
        cross_kernels = ((1, 3), (3, 1))
        blocks = (Bottleneck(self.c_, self.c_, shortcut, g, k=cross_kernels, e=1) for _ in range(n))
        self.m = nn.Sequential(*blocks)  # replaces the Bottleneck stack built by C3.__init__
|
|
|
|
|
|
|
|
|
|
|
|
class C3TR(C3):
    """C3 module whose Bottleneck stack is replaced by a TransformerBlock (4 heads, ``n`` layers)."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)
        num_heads = 4
        self.m = TransformerBlock(hidden, hidden, num_heads, n)
|
|
|
|
|
|
|
|
|
|
|
|
class C3Ghost(C3):
    """C3 module whose Bottleneck stack is replaced by ``n`` GhostBottlenecks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__(c1, c2, n, shortcut, g, e)
        hidden = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*[GhostBottleneck(hidden, hidden) for _ in range(n)])
|
|
|
|
|
|
|
|
|
|
|
|
class SPP(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer (https://arxiv.org/abs/1406.4729).

    Concatenates the channel-reduced input with a stride-1 max-pool for each kernel size in ``k``.
    Padding is k // 2, so odd kernels keep the spatial size. A 1x1 Conv fuses the result.
    """

    def __init__(self, c1, c2, k=(5, 9, 13)):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * (len(k) + 1), c2, 1, 1)
        pools = [nn.MaxPool2d(kernel_size=size, stride=1, padding=size // 2) for size in k]
        self.m = nn.ModuleList(pools)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            features = [x]
            features.extend(pool(x) for pool in self.m)
            return self.cv2(torch.cat(features, 1))
|
|
|
|
|
|
|
|
|
|
|
|
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher.

    One stride-1 max-pool is applied three times in sequence. Chained pools compound their
    windows, so with k=5 this matches SPP(k=(5, 9, 13)).
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress torch 1.9.0 max_pool2d() warning
            pyramid = [x]
            for _ in range(3):
                pyramid.append(self.m(pyramid[-1]))
            return self.cv2(torch.cat(pyramid, 1))
|
|
|
|
|
|
|
|
|
|
|
|
class Focus(nn.Module):
    """Focus wh information into c-space (space-to-depth followed by a Conv).

    The four parity sub-grids of the last two dimensions are stacked on the channel axis
    (4 * c1 channels at half resolution). The last two dimensions must be even for the
    slices to concatenate.
    Args: ch_in, ch_out, kernel, stride, padding, groups, activation.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        self.conv = Conv(c1 * 4, c2, k, s, p, g, act=act)

    def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
        even_even = x[..., ::2, ::2]
        odd_even = x[..., 1::2, ::2]
        even_odd = x[..., ::2, 1::2]
        odd_odd = x[..., 1::2, 1::2]
        stacked = torch.cat((even_even, odd_even, even_odd, odd_odd), 1)
        return self.conv(stacked)
|
|
|
|
|
|
|
|
|
|
|
|
class GhostConv(nn.Module):
    """Ghost Convolution (https://github.com/huawei-noah/ghostnet).

    A regular Conv produces c2 // 2 "primary" channels. A 5x5 depth-wise Conv (groups equal to
    its channel count) derives the same number of extra channels from them, and both halves are
    concatenated.
    Args: ch_in, ch_out, kernel, stride, groups, activation.
    """

    def __init__(self, c1, c2, k=1, s=1, g=1, act=True):
        super().__init__()
        half = c2 // 2  # hidden channels
        self.cv1 = Conv(c1, half, k, s, None, g, act=act)
        self.cv2 = Conv(half, half, 5, 1, None, half, act=act)

    def forward(self, x):
        primary = self.cv1(x)
        cheap = self.cv2(primary)
        return torch.cat((primary, cheap), 1)
|
|
|
|
|
|
|
|
|
|
|
|
class GhostBottleneck(nn.Module):
    """Ghost Bottleneck (https://github.com/huawei-noah/ghostnet).

    Main path: GhostConv (pw) -> depth-wise Conv only when s == 2 -> GhostConv (pw-linear).
    Shortcut: identity unless s == 2, in which case depth-wise Conv + 1x1 Conv. The two paths
    are summed.
    Args: ch_in, ch_out, kernel, stride.
    """

    def __init__(self, c1, c2, k=3, s=1):
        super().__init__()
        hidden = c2 // 2
        downsample = s == 2
        self.conv = nn.Sequential(
            GhostConv(c1, hidden, 1, 1),  # pw
            DWConv(hidden, hidden, k, s, act=False) if downsample else nn.Identity(),  # dw
            GhostConv(hidden, c2, 1, 1, act=False),  # pw-linear
        )
        if downsample:
            self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1, act=False))
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        return self.conv(x) + self.shortcut(x)
|
|
|
|
|
|
|
|
|
|
|
|
class Concat(nn.Module):
    """Concatenate a list of tensors along a fixed dimension (default 1)."""

    def __init__(self, dimension=1):
        super().__init__()
        self.d = dimension  # dimension handed to torch.cat

    def forward(self, x):
        """Return the tensors in ``x`` joined along ``self.d``."""
        return torch.cat(x, dim=self.d)
|
|
|
|
|
|
|
|
|
|
|
|
class AutoShape(nn.Module):
    # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
    # The class attributes below are the defaults read by forward(); assigning them on an instance overrides them.
    conf = 0.25  # NMS confidence threshold
    iou = 0.45  # NMS IoU threshold
    agnostic = False  # NMS class-agnostic
    multi_label = False  # NMS multiple labels per box
    classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
    max_det = 1000  # maximum number of detections per image
    amp = False  # Automatic Mixed Precision (AMP) inference

    def __init__(self, model, verbose=True):
        """Wrap ``model`` for inference on mixed-type image inputs.

        Args:
            model: a PyTorch model or an AutoBackend instance; it is switched to eval mode.
            verbose (bool): log a message when the wrapper is created.
        """
        super().__init__()
        if verbose:
            LOGGER.info('Adding AutoShape... ')
        copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=())  # copy attributes
        self.dmb = isinstance(model, AutoBackend)  # DetectMultiBackend() instance
        self.pt = not self.dmb or model.pt  # PyTorch model
        self.model = model.eval()
        if self.pt:
            # The last entry of the wrapped module list is taken to be the Detect head.
            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
            m.inplace = False  # Detect.inplace=False for safe multithread inference
            m.export = True  # do not output loss values

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        if self.pt:
            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

    @smart_inference_mode()
    def forward(self, ims, size=640, augment=False, profile=False):
        # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
        #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
        #
        # Returns: for a torch.Tensor input, the raw model output (no letterboxing, no NMS);
        # otherwise a Detections object with NMS-filtered boxes rescaled to each original image.
        # NOTE(review): the ``profile`` argument is accepted but not used in this method.

        dt = (Profile(), Profile(), Profile())  # timers: pre-process, inference, NMS
        with dt[0]:
            if isinstance(size, int):  # expand
                size = (size, size)
            p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
            autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
            if isinstance(ims, torch.Tensor):  # torch
                with amp.autocast(autocast):
                    return self.model(ims.to(p.device).type_as(p), augment=augment)  # inference

            # Pre-process
            n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
            shape0, shape1, files = [], [], []  # image and inference shapes, filenames
            for i, im in enumerate(ims):
                f = f'image{i}'  # filename
                if isinstance(im, (str, Path)):  # filename or uri
                    # NOTE(review): requests.get is called without a timeout, so a stalled URL can block here.
                    im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
                    im = np.asarray(ImageOps.exif_transpose(im))
                elif isinstance(im, Image.Image):  # PIL Image
                    im, f = np.asarray(ImageOps.exif_transpose(im)), getattr(im, 'filename', f) or f
                files.append(Path(f).with_suffix('.jpg').name)
                if im.shape[0] < 5:  # image in CHW
                    im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
                im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                s = im.shape[:2]  # HWC
                shape0.append(s)  # image shape
                g = max(size) / max(s)  # gain
                shape1.append([y * g for y in s])
                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
            # One shared inference shape for the batch: per-axis max of the scaled shapes, passed through
            # make_divisible with the model stride (PyTorch models only; other backends use ``size`` as-is).
            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
            x = [LetterBox(shape1, auto=False)(image=im)["img"] for im in ims]  # pad
            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

        with amp.autocast(autocast):
            # Inference
            with dt[1]:
                y = self.model(x, augment=augment)  # forward

            # Post-process
            with dt[2]:
                y = non_max_suppression(y if self.dmb else y[0],
                                        self.conf,
                                        self.iou,
                                        self.classes,
                                        self.agnostic,
                                        self.multi_label,
                                        max_det=self.max_det)  # NMS
                for i in range(n):
                    # return value unused: presumably rescales the box columns of y[i] in place
                    scale_boxes(shape1, y[i][:, :4], shape0[i])

            return Detections(ims, y, files, dt, self.names, x.shape)
|
|
|
|
|
|
|
|
|
|
|
|
class Detections:
    """YOLOv5 detections class for inference results.

    Holds the images, per-image predictions and derived box formats (pixel and normalized,
    xyxy and xywh). Provides print/show/save/crop/render helpers and pandas/list views.
    """

    def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
        """
        Args:
            ims (list[np.ndarray]): images; shape[0] is read as height and shape[1] as width.
            pred (list[torch.Tensor]): per-image rows of (x1, y1, x2, y2, conf, cls).
            files (list[str]): image filenames used when showing/saving.
            times (tuple): three profiling entries (pre-process, inference, NMS). Each is either an
                object with a ``.t`` attribute in seconds (e.g. ``Profile``) or a plain number.
            names: class names, indexed by integer class id.
            shape: inference BCHW shape, or None when unknown.
        """
        super().__init__()
        # Device for the normalization tensors; an empty batch has no predictions to take it from.
        d = pred[0].device if len(pred) else None  # device
        # (w, h, w, h, 1, 1) per image: box columns become 0-1 normalized, conf/cls are left unchanged
        gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims]  # normalizations
        self.ims = ims  # list of images as numpy arrays
        self.pred = pred  # list of tensors pred[0] = (xyxy, conf, cls)
        self.names = names  # class names
        self.files = files  # image filenames
        self.times = times  # profiling times
        self.xyxy = pred  # xyxy pixels
        self.xywh = [xyxy2xywh(x) for x in pred]  # xywh pixels
        self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)]  # xyxy normalized
        self.xywhn = [x / g for x, g in zip(self.xywh, gn)]  # xywh normalized
        self.n = len(self.pred)  # number of images (batch size)
        # Per-image milliseconds. Plain numbers are accepted so the default times=(0, 0, 0) works,
        # and max(n, 1) avoids dividing by zero for an empty batch.
        self.t = tuple(getattr(x, 't', x) / max(self.n, 1) * 1E3 for x in times)  # timestamps (ms)
        self.s = tuple(shape) if shape is not None else None  # inference BCHW shape

    def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
        """Shared worker behind print/show/save/crop/render.

        Returns the summary string when ``pprint`` is set, the list of crop dicts (keys box, conf,
        cls, label, im) when ``crop`` is set, and None otherwise. With ``render`` the annotated
        images are written back into ``self.ims``.
        """
        s, crops = '', []
        for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
            s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} '  # string
            if pred.shape[0]:
                for c in pred[:, -1].unique():
                    n = (pred[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                s = s.rstrip(', ')
                if show or save or render or crop:
                    annotator = Annotator(im, example=str(self.names))
                    for *box, conf, cls in reversed(pred):  # xyxy, confidence, class
                        label = f'{self.names[int(cls)]} {conf:.2f}'
                        if crop:
                            file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
                            crops.append({
                                'box': box,
                                'conf': conf,
                                'cls': cls,
                                'label': label,
                                'im': save_one_box(box, im, file=file, save=save)})
                        else:  # all others
                            annotator.box_label(box, label if labels else '', color=colors(cls))
                    im = annotator.im
            else:
                s += '(no detections)'

            im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im  # from np
            if show:
                im.show(self.files[i])  # show
            if save:
                f = self.files[i]
                im.save(save_dir / f)  # save
                if i == self.n - 1:
                    LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
            if render:
                self.ims[i] = np.asarray(im)
        if pprint:
            s = s.lstrip('\n')
            # Format the timings inside the f-string. Applying '%' to the finished text would fail
            # or misformat whenever s contains a '%' (e.g. from a class name).
            t = self.t
            return (f'{s}\nSpeed: {t[0]:.1f}ms pre-process, {t[1]:.1f}ms inference, '
                    f'{t[2]:.1f}ms NMS per image at shape {self.s}')
        if crop:
            if save:
                LOGGER.info(f'Saved results to {save_dir}\n')
            return crops

    def show(self, labels=True):
        """Display each annotated image."""
        self._run(show=True, labels=labels)  # show results

    def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
        """Save annotated images into an incremented ``save_dir``."""
        save_dir = increment_path(save_dir, exist_ok, mkdir=True)  # increment save_dir
        self._run(save=True, labels=labels, save_dir=save_dir)  # save results

    def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
        """Return detection crops; also write them under ``save_dir/crops`` when ``save`` is set."""
        save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
        return self._run(crop=True, save=save, save_dir=save_dir)  # crop results

    def render(self, labels=True):
        """Annotate ``self.ims`` in place and return them."""
        self._run(render=True, labels=labels)  # render results
        return self.ims

    def pandas(self):
        """Return a copy whose box attributes are lists of DataFrames, i.e. print(results.pandas().xyxy[0])."""
        new = copy(self)  # return copy
        ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name'  # xyxy columns
        cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name'  # xywh columns
        for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
            # each row: 4 box values + confidence, then integer class id and its name
            a = [[row[:5] + [int(row[5]), self.names[int(row[5])]] for row in t.tolist()] for t in getattr(self, k)]
            setattr(new, k, [pd.DataFrame(rows, columns=c) for rows in a])
        return new

    def tolist(self):
        """Return a list of single-image Detections objects, i.e. 'for result in results.tolist():'."""
        r = range(self.n)  # iterable
        x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
        return x

    def print(self):
        """Log the summary string."""
        LOGGER.info(self.__str__())

    def __len__(self):  # override len(results)
        return self.n

    def __str__(self):  # override print(results)
        return self._run(pprint=True)  # print results

    def __repr__(self):
        return f'YOLOv5 {self.__class__} instance\n' + self.__str__()
|
|
|
|
|
|
|
|
|
|
|
|
class Proto(nn.Module):
    """YOLOv5 mask Proto module for segmentation models.

    3x3 Conv -> 2x nearest upsample -> 3x3 Conv -> 1x1 Conv producing ``c2`` mask channels.
    Args: ch_in, number of protos (hidden channels ``c_``), number of masks.
    """

    def __init__(self, c1, c_=256, c2=32):
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x):
        y = self.cv1(x)
        y = self.upsample(y)
        y = self.cv2(y)
        return self.cv3(y)
|
|
|
|
|
|
|
|
|
|
|
|
class Ensemble(nn.ModuleList):
    """Ensemble of models: the first output of each member is concatenated along dim 1 (NMS ensemble)."""

    def __init__(self):
        super().__init__()

    def forward(self, x, augment=False, profile=False, visualize=False):
        outputs = []
        for model in self:
            outputs.append(model(x, augment, profile, visualize)[0])
        # Alternatives: torch.stack(outputs).max(0)[0] (max ensemble) or torch.stack(outputs).mean(0) (mean)
        return torch.cat(outputs, 1), None  # inference, train output
|
|
|
|
|
|
|
|
|
|
|
|
# heads
|
|
|
|
class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        """
        Args:
            nc (int): number of classes.
            anchors: one flat (w0, h0, w1, h1, ...) sequence per detection layer; stored as an
                (nl, na, 2) buffer.
            ch: input channels of each detection layer (one 1x1 output conv per entry).
            inplace (bool): stored on the module; not read by the methods of this class.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor (2 xy + 2 wh + nc + 1 scores, see forward split)
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)

    def forward(self, x):
        """Run the head on a list of feature maps, one per detection layer.

        ``x`` is modified in place: each entry is replaced by its conv output reshaped to
        (bs, na, ny, nx, no). In training mode that list is returned. Otherwise the decoded
        predictions, concatenated to (bs, sum of na*ny*nx, no), are returned as a 1-tuple when
        ``export`` is set, or together with the list when it is not.
        """
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                # Rebuild cached grids when forced or when the feature-map size changed (initially empty).
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    # mask coefficients are passed through without a sigmoid
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    # grid already includes the -0.5 offset, so xy = (2*sigmoid - 0.5 + cell) * stride
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
        """Build the (1, na, ny, nx, 2) cell-offset grid and anchor-size grid for layer ``i``.

        The ``torch_1_10`` default is evaluated once, when the class is defined.
        """
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x)  # torch>=0.7 compatibility
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        # NOTE(review): anchors appear to be stored in stride units (scaled to pixels here) - confirm with the builder
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid
|
|
|
|
|
|
|
|
|
|
|
|
class Segment(Detect):
    """YOLOv5 Segment head: Detect outputs extended with ``nm`` mask coefficients, plus Proto masks.

    Args: nc, anchors, nm (number of masks), npr (number of protos), ch (input channels), inplace.
    """

    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
        super().__init__(nc, anchors, ch, inplace)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        # Rebuild the output convs: Detect.__init__ sized them without the mask coefficients
        self.m = nn.ModuleList(nn.Conv2d(channels, self.no * self.na, 1) for channels in ch)  # output conv
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward  # plain function, called with self explicitly below

    def forward(self, x):
        # Protos must be computed first: Detect.forward overwrites the entries of x in place.
        protos = self.proto(x[0])
        out = self.detect(self, x)
        if self.training:
            return out, protos
        if self.export:
            return out[0], protos
        return out[0], protos, out[1]
|
|
|
|
|
|
|
|
|
|
|
|
class Classify(nn.Module):
    """YOLOv5 classification head, i.e. x(b,c1,20,20) to x(b,c2).

    Conv to 1280 channels -> global average pool -> Dropout(p=0.0) -> Linear.
    Args: ch_in, ch_out, kernel, stride, padding, groups.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):
        super().__init__()
        hidden = 1280  # efficientnet_b0 size
        self.conv = Conv(c1, hidden, k, s, autopad(k, p), g)
        self.pool = nn.AdaptiveAvgPool2d(1)  # to x(b,hidden,1,1)
        self.drop = nn.Dropout(p=0.0, inplace=True)
        self.linear = nn.Linear(hidden, c2)  # to x(b,c2)

    def forward(self, x):
        if isinstance(x, list):
            x = torch.cat(x, 1)  # a list of feature maps is first merged along dim 1
        features = self.pool(self.conv(x)).flatten(1)
        return self.linear(self.drop(features))
|