Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
@@ -662,12 +662,10 @@ class Segment(Detect):
         self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
 
     def forward(self, x):
-        p = self.proto(x[0])
+        p = self.proto(x[0])  # mask protos
+        bs = p.shape[0]  # batch size
 
-        mc = []  # mask coefficient
-        for i in range(self.nl):
-            mc.append(self.cv4[i](x[i]))
-        mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
+        mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)  # mask coefficients
         x = self.detect(self, x)
         if self.training:
             return x, mc, p
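The hunk above only vectorizes the mask-coefficient computation. A minimal standalone check (not part of the commit; dummy modules and feature-map shapes assumed) that the loop form and the comprehension form agree:

import torch
import torch.nn as nn

nm, nl, bs = 32, 3, 2                                 # mask channels, detect layers, batch size (hypothetical)
ch = (64, 128, 256)                                   # per-level input channels (hypothetical)
cv4 = nn.ModuleList(nn.Conv2d(c, nm, 1) for c in ch)  # stand-in for Segment.cv4
x = [torch.randn(bs, c, s, s) for c, s in zip(ch, (80, 40, 20))]

# old formulation: accumulate per-level outputs, then view and concatenate
mc_old = []
for i in range(nl):
    mc_old.append(cv4[i](x[i]))
mc_old = torch.cat([mi.view(bs, nm, -1) for mi in mc_old], 2)

# new formulation: single comprehension, identical result
mc_new = torch.cat([cv4[i](x[i]).view(bs, nm, -1) for i in range(nl)], 2)

assert torch.equal(mc_old, mc_new)
assert mc_new.shape == (bs, nm, 80 * 80 + 40 * 40 + 20 * 20)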
@@ -1,11 +1,9 @@
 import contextlib
 from copy import deepcopy
 from pathlib import Path
 
-import thop
 import torch
 import torch.nn as nn
-import torchvision
 
 from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                     Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
@@ -226,9 +224,15 @@ class SegmentationModel(DetectionModel):
 
 class ClassificationModel(BaseModel):
     # YOLOv5 classification model
-    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
+    def __init__(self,
+                 cfg=None,
+                 model=None,
+                 ch=3,
+                 nc=1000,
+                 cutoff=10,
+                 verbose=True):  # yaml, model, number of classes, cutoff index
         super().__init__()
-        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
+        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_detection_model(self, model, nc=1000, cutoff=10):
         # Create a YOLOv5 classification model from a YOLOv5 detection model
@@ -246,9 +250,15 @@ class ClassificationModel(BaseModel):
         self.save = []
         self.nc = nc
 
-    def _from_yaml(self, cfg):
-        # TODO: Create a YOLOv5 classification model from a *.yaml file
-        self.model = None
+    def _from_yaml(self, cfg, ch, nc, verbose):
+        self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True)  # cfg dict
+        # Define model
+        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
+        if nc and nc != self.yaml['nc']:
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            self.yaml['nc'] = nc  # override yaml value
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose)  # model, savelist
+        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
 
     def load(self, weights):
         model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
@@ -351,7 +361,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
-    # Parse a YOLOv5 model.yaml dictionary
+    # Parse a YOLO model.yaml dictionary
     if verbose:
         LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
     nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
@@ -359,7 +369,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
         Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
         if verbose:
             LOGGER.info(f"{colorstr('activation:')} {act}")  # print
-    no = nc + 4  # number of outputs = classes + box
 
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
@@ -370,10 +379,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
         n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in {
-                Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP,
-                C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+                Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
+                BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
             c1, c2 = ch[f], args[0]
-            if c2 != no:  # if not output
+            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(c2 * gw, 8)
 
             args = [c1, c2, *args[1:]]
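The c2 != nc guard above exists so that a head whose output width already equals the class count (e.g. Classify) is not rescaled by width_multiple. A small illustration with made-up values, re-implementing make_divisible locally for the sketch:

import math

def make_divisible(x, divisor=8):  # same rounding rule parse_model applies
    return math.ceil(x / divisor) * divisor

gw, nc = 0.5, 10  # hypothetical width_multiple and class count
for c2 in (64, 128, nc):
    scaled = c2 if c2 == nc else make_divisible(c2 * gw, 8)
    print(c2, '->', scaled)  # 64 -> 32, 128 -> 64, 10 -> 10 (class-count output left untouched)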
@@ -384,7 +393,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        # TODO: channel, gw, gd
         elif m in {Detect, Segment}:
             args.append([ch[x] for x in f])
             if m is Segment:
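Taken together, these changes let a ClassificationModel be built directly from a model dict or YAML file via parse_model. A minimal usage sketch, not taken from the commit: the config below is a made-up toy layout that only uses the keys parse_model reads here, and the import path assumes ClassificationModel lives in ultralytics.nn.tasks (the file names are not shown on this page).

import torch

from ultralytics.nn.tasks import ClassificationModel  # assumed module path

cfg = {
    'nc': 10,               # number of classes
    'depth_multiple': 1.0,  # scales repeat counts n
    'width_multiple': 1.0,  # scales output channels c2
    'backbone': [
        # [from, repeats, module, args]
        [-1, 1, 'Conv', [64, 3, 2]],
        [-1, 1, 'Conv', [128, 3, 2]],
        [-1, 3, 'C2f', [128, True]],
    ],
    'head': [
        [-1, 1, 'Classify', [10]],  # args[0] == nc, so it is exempt from width scaling
    ],
}

model = ClassificationModel(cfg=cfg, ch=3, nc=10, verbose=True)  # dict accepted directly by _from_yaml
preds = model(torch.randn(1, 3, 224, 224))                       # expected logits of shape (1, 10)
print(preds.shape)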