Integration of v8 segmentation (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2022-12-28 23:01:38 +08:00
committed by GitHub
parent 384f0ef1c6
commit 8406b49b49
16 changed files with 422 additions and 224 deletions

View File

@ -576,11 +576,11 @@ class Detections:
class Proto(nn.Module):
# YOLOv5 mask Proto module for segmentation models
# YOLOv8 mask Proto module for segmentation models
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
super().__init__()
self.cv1 = Conv(c1, c_, k=3)
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
self.cv2 = Conv(c_, c_, k=3)
self.cv3 = Conv(c_, c2)
@ -628,16 +628,16 @@ class Detect(nn.Module):
shape = x[0].shape # BCHW
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
if self.training:
return x, box, cls
return x
elif self.dynamic or self.shape != shape:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape
box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, (x, box, cls))
return y if self.export else (y, x)
def bias_init(self):
# Initialize Detect() biases, WARNING: requires stride availability
@ -651,19 +651,27 @@ class Detect(nn.Module):
class Segment(Detect):
# YOLOv5 Segment head for segmentation models
def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=()):
super().__init__(nc, anchors, ch)
def __init__(self, nc=80, nm=32, npr=256, ch=()):
super().__init__(nc, ch)
self.nm = nm # number of masks
self.npr = npr # number of protos
self.no = 5 + nc + self.nm # number of outputs per anchor
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
self.proto = Proto(ch[0], self.npr, self.nm) # protos
self.detect = Detect.forward
c4 = max(ch[0] // 4, self.nm)
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
def forward(self, x):
p = self.proto(x[0])
mc = [] # mask coefficient
for i in range(self.nl):
mc.append(self.cv4[i](x[i]))
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
x = self.detect(self, x)
return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])
if self.training:
return x, mc, p
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
class Classify(nn.Module):

View File

@ -101,7 +101,7 @@ class DetectionModel(BaseModel):
if isinstance(m, (Detect, Segment)):
s = 256 # 2x min stride
m.inplace = self.inplace
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Detect)) else self.forward(x)
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride
m.bias_init() # only run once
@ -163,8 +163,8 @@ class DetectionModel(BaseModel):
class SegmentationModel(DetectionModel):
# YOLOv5 segmentation model
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None):
super().__init__(cfg, ch, nc)
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True):
super().__init__(cfg, ch, nc, verbose)
class ClassificationModel(BaseModel):
@ -300,7 +300,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
elif m in {Detect, Segment}:
args.append([ch[x] for x in f])
if m is Segment:
args[3] = make_divisible(args[3] * gw, 8)
args[2] = make_divisible(args[2] * gw, 8)
else:
c2 = ch[f]