Add Conv2() module (#2820)

single_channel
Glenn Jocher 2 years ago committed by GitHub
parent d19c5b6ce8
commit 441e67d330
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,16 +48,17 @@ trainer.train()
You now realize that you need to customize the trainer further to: You now realize that you need to customize the trainer further to:
* * Customize the `loss function`. * Customize the `loss function`.
* Add `callback` that uploads model to your Google Drive after every 10 `epochs` * Add `callback` that uploads model to your Google Drive after every 10 `epochs`
Here's how you can do it: Here's how you can do it:
```python ```python
from ultralytics.yolo.v8.detect import DetectionTrainer from ultralytics.yolo.v8.detect import DetectionTrainer
from ultralytcs.nn.tasks import DetectionModel from ultralytics.nn.tasks import DetectionModel
class MyCustomModel(DetectionModel): class MyCustomModel(DetectionModel):
def init_criterion(): def init_criterion(self):
... ...
@ -65,6 +66,7 @@ class CustomTrainer(DetectionTrainer):
def get_model(self, cfg, weights): def get_model(self, cfg, weights):
return MyCustomModel(...) return MyCustomModel(...)
# callback to upload model weights # callback to upload model weights
def log_model(trainer): def log_model(trainer):
last_weight_path = trainer.last last_weight_path = trainer.last

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-x object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3-SPP object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3-tiny object detection model with P4-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv3 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov3
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv5 object detection model with P3-P6 outputs. For details see https://docs.ultralytics.com/models/yolov5
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv5 object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/yolov5
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/models/yolov6
# Parameters # Parameters
act: nn.ReLU() act: nn.ReLU()
@ -23,29 +23,31 @@ backbone:
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16 - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 18, Conv, [512, 3, 1]] - [-1, 18, Conv, [512, 3, 1]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32 - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [ -1, 9, Conv, [ 1024, 3, 1 ] ] - [-1, 6, Conv, [1024, 3, 1]]
- [-1, 1, SPPF, [1024, 5]] # 9 - [-1, 1, SPPF, [1024, 5]] # 9
# YOLOv6-3.0s head # YOLOv6-3.0s head
head: head:
- [-1, 1, Conv, [256, 1, 1]]
- [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]] - [-1, 1, nn.ConvTranspose2d, [256, 2, 2, 0]]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4 - [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 1, Conv, [256, 3, 1]] - [-1, 1, Conv, [256, 3, 1]]
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 13 - [-1, 9, Conv, [256, 3, 1]] # 14
- [-1, 1, Conv, [128, 1, 1]]
- [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]] - [-1, 1, nn.ConvTranspose2d, [128, 2, 2, 0]]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3 - [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 1, Conv, [128, 3, 1]] - [-1, 1, Conv, [128, 3, 1]]
- [ -1, 9, Conv, [ 128, 3, 1 ] ] # 17 - [-1, 9, Conv, [128, 3, 1]] # 19
- [-1, 1, Conv, [128, 3, 2]] - [-1, 1, Conv, [128, 3, 2]]
- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4 - [[-1, 15], 1, Concat, [1]] # cat head P4
- [-1, 1, Conv, [256, 3, 1]] - [-1, 1, Conv, [256, 3, 1]]
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 21 - [-1, 9, Conv, [256, 3, 1]] # 23
- [-1, 1, Conv, [256, 3, 2]] - [-1, 1, Conv, [256, 3, 2]]
- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5 - [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 1, Conv, [512, 3, 1]] - [-1, 1, Conv, [512, 3, 1]]
- [ -1, 9, Conv, [ 512, 3, 1 ] ] # 25 - [-1, 9, Conv, [512, 3, 1]] # 27
- [ [ 17, 21, 25 ], 1, Detect, [ nc ] ] # Detect(P3, P4, P5) - [[19, 23, 27], 1, Detect, [nc]] # Detect(P3, P4, P5)

@ -1,5 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P6 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect # YOLOv8-pose keypoints/pose estimation model. For Usage examples see https://docs.ultralytics.com/tasks/pose
# Parameters # Parameters
nc: 1 # number of classes nc: 1 # number of classes

@ -1,15 +1,28 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Ultralytics modules. Visualize with:
from ultralytics.nn.modules import *
import torch
import os
x = torch.ones(1, 128, 40, 40)
m = Conv(128, 128)
f = f'{m._get_name()}.onnx'
torch.onnx.export(m, x, f)
os.system(f'onnxsim {f} {f} && open {f}')
"""
from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck, from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
HGBlock, HGStem, Proto, RepC3) HGBlock, HGStem, Proto, RepC3)
from .conv import (CBAM, ChannelAttention, Concat, Conv, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv, from .conv import (CBAM, ChannelAttention, Concat, Conv, Conv2, ConvTranspose, DWConv, DWConvTranspose2d, Focus,
LightConv, RepConv, SpatialAttention) GhostConv, LightConv, RepConv, SpatialAttention)
from .head import Classify, Detect, Pose, RTDETRDecoder, Segment from .head import Classify, Detect, Pose, RTDETRDecoder, Segment
from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d, from .transformer import (AIFI, MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer, LayerNorm2d,
MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer) MLPBlock, MSDeformAttn, TransformerBlock, TransformerEncoderLayer, TransformerLayer)
__all__ = [ __all__ = [
'Conv', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', 'Conv', 'Conv2', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv',
'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock', 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock',
'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify', 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify',

@ -43,6 +43,27 @@ class Conv(nn.Module):
return self.act(self.conv(x)) return self.act(self.conv(x))
class Conv2(Conv):
"""Simplified RepConv module with Conv fusing."""
def __init__(self, c1, c2, k=3, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__(c1, c2, k, s, p, g=g, d=d, act=act)
self.cv2 = nn.Conv2d(c1, c2, 1, s, autopad(1, p, d), groups=g, dilation=d, bias=False) # add 1x1 conv
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x) + self.cv2(x)))
def fuse_convs(self):
"""Fuse parallel convolutions."""
w = torch.zeros_like(self.conv.weight.data)
i = [x // 2 for x in w.shape[2:]]
w[:, :, i[0] - 1:i[0], i[1] - 1:i[1]] = self.cv2.weight.data.clone()
self.conv.weight.data += w
self.__delattr__('cv2')
class LightConv(nn.Module): class LightConv(nn.Module):
"""Light convolution with args(ch_in, ch_out, kernel). """Light convolution with args(ch_in, ch_out, kernel).
https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py

@ -8,9 +8,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
Segment) RTDETRDecoder, Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
@ -103,7 +103,9 @@ class BaseModel(nn.Module):
""" """
if not self.is_fused(): if not self.is_fused():
for m in self.model.modules(): for m in self.model.modules():
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'): if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
if isinstance(m, Conv2):
m.fuse_convs()
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward m.forward = m.forward_fuse # update forward

Loading…
Cancel
Save