|
|
|
@ -34,7 +34,7 @@ class Detect(nn.Module):
|
|
|
|
|
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
|
|
|
|
|
self.no = nc + self.reg_max * 4 # number of outputs per anchor
|
|
|
|
|
self.stride = torch.zeros(self.nl) # strides computed during build
|
|
|
|
|
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], self.nc) # channels
|
|
|
|
|
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
|
|
|
|
|
self.cv2 = nn.ModuleList(
|
|
|
|
|
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
|
|
|
|
|
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
|
|
|
|
|