Integration of v8 segmentation (#107)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -576,11 +576,11 @@ class Detections:
|
||||
|
||||
|
||||
class Proto(nn.Module):
|
||||
# YOLOv5 mask Proto module for segmentation models
|
||||
# YOLOv8 mask Proto module for segmentation models
|
||||
def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
|
||||
super().__init__()
|
||||
self.cv1 = Conv(c1, c_, k=3)
|
||||
self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
|
||||
self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True) # nn.Upsample(scale_factor=2, mode='nearest')
|
||||
self.cv2 = Conv(c_, c_, k=3)
|
||||
self.cv3 = Conv(c_, c2)
|
||||
|
||||
@ -628,16 +628,16 @@ class Detect(nn.Module):
|
||||
shape = x[0].shape # BCHW
|
||||
for i in range(self.nl):
|
||||
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
|
||||
box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
|
||||
if self.training:
|
||||
return x, box, cls
|
||||
return x
|
||||
elif self.dynamic or self.shape != shape:
|
||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
|
||||
y = torch.cat((dbox, cls.sigmoid()), 1)
|
||||
return y if self.export else (y, (x, box, cls))
|
||||
return y if self.export else (y, x)
|
||||
|
||||
def bias_init(self):
|
||||
# Initialize Detect() biases, WARNING: requires stride availability
|
||||
@ -651,19 +651,27 @@ class Detect(nn.Module):
|
||||
|
||||
class Segment(Detect):
|
||||
# YOLOv5 Segment head for segmentation models
|
||||
def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=()):
|
||||
super().__init__(nc, anchors, ch)
|
||||
def __init__(self, nc=80, nm=32, npr=256, ch=()):
|
||||
super().__init__(nc, ch)
|
||||
self.nm = nm # number of masks
|
||||
self.npr = npr # number of protos
|
||||
self.no = 5 + nc + self.nm # number of outputs per anchor
|
||||
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
|
||||
self.proto = Proto(ch[0], self.npr, self.nm) # protos
|
||||
self.detect = Detect.forward
|
||||
|
||||
c4 = max(ch[0] // 4, self.nm)
|
||||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
||||
|
||||
def forward(self, x):
|
||||
p = self.proto(x[0])
|
||||
|
||||
mc = [] # mask coefficient
|
||||
for i in range(self.nl):
|
||||
mc.append(self.cv4[i](x[i]))
|
||||
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
|
||||
x = self.detect(self, x)
|
||||
return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])
|
||||
if self.training:
|
||||
return x, mc, p
|
||||
return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
|
||||
|
||||
|
||||
class Classify(nn.Module):
|
||||
|
@ -101,7 +101,7 @@ class DetectionModel(BaseModel):
|
||||
if isinstance(m, (Detect, Segment)):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Detect)) else self.forward(x)
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
@ -163,8 +163,8 @@ class DetectionModel(BaseModel):
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv5 segmentation model
|
||||
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None):
|
||||
super().__init__(cfg, ch, nc)
|
||||
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg, ch, nc, verbose)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
@ -300,7 +300,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||
elif m in {Detect, Segment}:
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[3] = make_divisible(args[3] * gw, 8)
|
||||
args[2] = make_divisible(args[2] * gw, 8)
|
||||
else:
|
||||
c2 = ch[f]
|
||||
|
||||
|
Reference in New Issue
Block a user