Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-08 00:34:34 +05:30
committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions

View File

@ -662,12 +662,10 @@ class Segment(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
def forward(self, x):
p = self.proto(x[0])
p = self.proto(x[0]) # mask protos
bs = p.shape[0] # batch size
mc = [] # mask coefficient
for i in range(self.nl):
mc.append(self.cv4[i](x[i]))
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
x = self.detect(self, x)
if self.training:
return x, mc, p

View File

@ -1,11 +1,9 @@
import contextlib
from copy import deepcopy
from pathlib import Path
import thop
import torch
import torch.nn as nn
import torchvision
from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
@ -226,9 +224,15 @@ class SegmentationModel(DetectionModel):
class ClassificationModel(BaseModel):
# YOLOv5 classification model
def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
def __init__(self,
cfg=None,
model=None,
ch=3,
nc=1000,
cutoff=10,
verbose=True): # yaml, model, number of classes, cutoff index
super().__init__()
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
def _from_detection_model(self, model, nc=1000, cutoff=10):
# Create a YOLOv5 classification model from a YOLOv5 detection model
@ -246,9 +250,15 @@ class ClassificationModel(BaseModel):
self.save = []
self.nc = nc
def _from_yaml(self, cfg):
# TODO: Create a YOLOv5 classification model from a *.yaml file
self.model = None
def _from_yaml(self, cfg, ch, nc, verbose):
self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True) # cfg dict
# Define model
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
if nc and nc != self.yaml['nc']:
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
self.yaml['nc'] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose) # model, savelist
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
def load(self, weights):
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
@ -351,7 +361,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
# Parse a YOLO model.yaml dictionary
if verbose:
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
@ -359,7 +369,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # print
no = nc + 4 # number of outputs = classes + box
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
@ -370,10 +379,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in {
Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP,
C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
c1, c2 = ch[f], args[0]
if c2 != no: # if not output
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
c2 = make_divisible(c2 * gw, 8)
args = [c1, c2, *args[1:]]
@ -384,7 +393,6 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[x] for x in f)
# TODO: channel, gw, gd
elif m in {Detect, Segment}:
args.append([ch[x] for x in f])
if m is Segment: