ultralytics 8.0.65
YOLOv8 Pose models (#1347)
Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Mert Can Demir <validatedev@gmail.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Fabian Greavu <fabiangreavu@gmail.com> Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com> Co-authored-by: Eric Pedley <ericpedley@gmail.com> Co-authored-by: JustasBart <40023722+JustasBart@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Aarni Koskela <akx@iki.fi> Co-authored-by: Sergio Sanchez <sergio.ssm.97@gmail.com> Co-authored-by: Bogdan Gheorghe <112427971+bogdan-galileo@users.noreply.github.com> Co-authored-by: Jaap van de Loosdrecht <jaap@vdlmv.nl> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: nerdyespresso <106761627+nerdyespresso@users.noreply.github.com> Co-authored-by: Farid Inawan <frdteknikelektro@gmail.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Alexander Duda <Alexander.Duda@me.com> Co-authored-by: Mehran Ghandehari <mehran.maps@gmail.com> Co-authored-by: Snyk bot <snyk-bot@snyk.io> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
@ -91,8 +91,10 @@ class AutoBackend(nn.Module):
|
||||
if nn_module:
|
||||
model = weights.to(device)
|
||||
model = model.fuse(verbose=verbose) if fuse else model
|
||||
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
||||
if hasattr(model, 'kpt_shape'):
|
||||
kpt_shape = model.kpt_shape # pose-only
|
||||
stride = max(int(model.stride.max()), 32) # model stride
|
||||
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
||||
model.half() if fp16 else model.float()
|
||||
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
|
||||
pt = True
|
||||
@ -102,6 +104,8 @@ class AutoBackend(nn.Module):
|
||||
device=device,
|
||||
inplace=True,
|
||||
fuse=fuse)
|
||||
if hasattr(model, 'kpt_shape'):
|
||||
kpt_shape = model.kpt_shape # pose-only
|
||||
stride = max(int(model.stride.max()), 32) # model stride
|
||||
names = model.module.names if hasattr(model, 'module') else model.names # get class names
|
||||
model.half() if fp16 else model.float()
|
||||
@ -268,13 +272,14 @@ class AutoBackend(nn.Module):
|
||||
for k, v in metadata.items():
|
||||
if k in ('stride', 'batch'):
|
||||
metadata[k] = int(v)
|
||||
elif k in ('imgsz', 'names') and isinstance(v, str):
|
||||
elif k in ('imgsz', 'names', 'kpt_shape') and isinstance(v, str):
|
||||
metadata[k] = eval(v)
|
||||
stride = metadata['stride']
|
||||
task = metadata['task']
|
||||
batch = metadata['batch']
|
||||
imgsz = metadata['imgsz']
|
||||
names = metadata['names']
|
||||
kpt_shape = metadata.get('kpt_shape')
|
||||
elif not (pt or triton or nn_module):
|
||||
LOGGER.warning(f"WARNING ⚠️ Metadata not found for 'model={weights}'")
|
||||
|
||||
|
@ -378,7 +378,9 @@ class Ensemble(nn.ModuleList):
|
||||
return y, None # inference, train output
|
||||
|
||||
|
||||
# heads
|
||||
# Model heads below ----------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Detect(nn.Module):
|
||||
# YOLOv8 Detect head for detection models
|
||||
dynamic = False # force grid reconstruction
|
||||
@ -394,7 +396,6 @@ class Detect(nn.Module):
|
||||
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
|
||||
self.no = nc + self.reg_max * 4 # number of outputs per anchor
|
||||
self.stride = torch.zeros(self.nl) # strides computed during build
|
||||
|
||||
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
||||
self.cv2 = nn.ModuleList(
|
||||
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
|
||||
@ -454,6 +455,36 @@ class Segment(Detect):
|
||||
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
||||
|
||||
|
||||
class Pose(Detect):
|
||||
# YOLOv8 Pose head for keypoints models
|
||||
def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
|
||||
super().__init__(nc, ch)
|
||||
self.kpt_shape = kpt_shape # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)
|
||||
self.nk = kpt_shape[0] * kpt_shape[1] # number of keypoints total
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.nk)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nk, 1)) for x in ch)
|
||||
|
||||
def forward(self, x):
|
||||
bs = x[0].shape[0] # batch size
|
||||
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
||||
x = self.detect(self, x)
|
||||
if self.training:
|
||||
return x, kpt
|
||||
pred_kpt = self.kpts_decode(kpt)
|
||||
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))
|
||||
|
||||
def kpts_decode(self, kpts):
|
||||
ndim = self.kpt_shape[1]
|
||||
y = kpts.clone()
|
||||
if ndim == 3:
|
||||
y[:, 2::3].sigmoid_() # inplace sigmoid
|
||||
y[:, 0::ndim] = (y[:, 0::ndim] * 2.0 + (self.anchors[0] - 0.5)) * self.strides
|
||||
y[:, 1::ndim] = (y[:, 1::ndim] * 2.0 + (self.anchors[1] - 0.5)) * self.strides
|
||||
return y
|
||||
|
||||
|
||||
class Classify(nn.Module):
|
||||
# YOLOv8 classification head, i.e. x(b,c1,20,20) to x(b,c2)
|
||||
def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
|
||||
|
@ -10,7 +10,7 @@ import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
|
||||
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
|
||||
GhostBottleneck, GhostConv, Segment)
|
||||
GhostBottleneck, GhostConv, Pose, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
@ -183,10 +183,10 @@ class DetectionModel(BaseModel):
|
||||
|
||||
# Build strides
|
||||
m = self.model[-1] # Detect()
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
if isinstance(m, (Detect, Segment, Pose)):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
@ -242,12 +242,23 @@ class DetectionModel(BaseModel):
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv8 segmentation model
|
||||
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg, ch, nc, verbose)
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def _forward_augment(self, x):
|
||||
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
||||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
# YOLOv8 pose model
|
||||
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = yaml_model_load(cfg) # load model YAML
|
||||
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
|
||||
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
||||
cfg['kpt_shape'] = data_kpt_shape
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
# YOLOv8 classification model
|
||||
def __init__(self,
|
||||
@ -425,7 +436,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
# Args
|
||||
max_channels = float('inf')
|
||||
nc, act, scales = (d.get(x) for x in ('nc', 'act', 'scales'))
|
||||
depth, width = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple'))
|
||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
||||
if scales:
|
||||
scale = d.get('scale')
|
||||
if not scale:
|
||||
@ -464,7 +475,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
elif m in (Detect, Segment):
|
||||
elif m in (Detect, Segment, Pose):
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[2] = make_divisible(min(args[2], max_channels) * width, 8)
|
||||
@ -543,6 +554,8 @@ def guess_model_task(model):
|
||||
return 'detect'
|
||||
if m == 'segment':
|
||||
return 'segment'
|
||||
if m == 'pose':
|
||||
return 'pose'
|
||||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
@ -565,6 +578,8 @@ def guess_model_task(model):
|
||||
return 'segment'
|
||||
elif isinstance(m, Classify):
|
||||
return 'classify'
|
||||
elif isinstance(m, Pose):
|
||||
return 'pose'
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
@ -573,10 +588,12 @@ def guess_model_task(model):
|
||||
return 'segment'
|
||||
elif '-cls' in model.stem or 'classify' in model.parts:
|
||||
return 'classify'
|
||||
elif '-pose' in model.stem or 'pose' in model.parts:
|
||||
return 'pose'
|
||||
elif 'detect' in model.parts:
|
||||
return 'detect'
|
||||
|
||||
# Unable to determine task from model
|
||||
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'task=segment' or 'task=classify'.")
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify', or 'pose'.")
|
||||
return 'detect' # assume detect
|
||||
|
Reference in New Issue
Block a user