diff --git a/ultralytics/nn/modules.py b/ultralytics/nn/modules.py
index d24d225..b98ba65 100644
--- a/ultralytics/nn/modules.py
+++ b/ultralytics/nn/modules.py
@@ -62,6 +62,9 @@ class ConvTranspose(nn.Module):
     def forward(self, x):
         return self.act(self.bn(self.conv_transpose(x)))
 
+    def forward_fuse(self, x):  # forward without batchnorm, used after fusion
+        return self.act(self.conv_transpose(x))
+
 
 class DFL(nn.Module):
     # Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index f431246..183670a 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -12,8 +12,8 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot
                                     GhostBottleneck, GhostConv, Segment)
 from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
 from ultralytics.yolo.utils.checks import check_requirements, check_yaml
-from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
-                                                model_info, scale_img, time_sync)
+from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
+                                                intersect_dicts, make_divisible, model_info, scale_img, time_sync)
 
 
 class BaseModel(nn.Module):
@@ -100,6 +100,10 @@ class BaseModel(nn.Module):
                 m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                 delattr(m, 'bn')  # remove batchnorm
                 m.forward = m.forward_fuse  # update forward
+            if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
+                m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)  # update conv_transpose
+                delattr(m, 'bn')  # remove batchnorm
+                m.forward = m.forward_fuse  # update forward
         self.info()
         return self
diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py
index 342b0aa..7beffbd 100644
--- a/ultralytics/yolo/utils/torch_utils.py
+++ b/ultralytics/yolo/utils/torch_utils.py
@@ -135,6 +135,31 @@ def fuse_conv_and_bn(conv, bn):
     return fusedconv
 
 
+def fuse_deconv_and_bn(deconv, bn):
+    # Fuse ConvTranspose2d() and BatchNorm2d() layers into a single ConvTranspose2d()
+    fuseddconv = nn.ConvTranspose2d(deconv.in_channels,
+                                    deconv.out_channels,
+                                    kernel_size=deconv.kernel_size,
+                                    stride=deconv.stride,
+                                    padding=deconv.padding,
+                                    output_padding=deconv.output_padding,
+                                    dilation=deconv.dilation,
+                                    groups=deconv.groups,
+                                    bias=True).requires_grad_(False).to(deconv.weight.device)
+
+    # Prepare filters: ConvTranspose2d weights are (in_channels, out_channels // groups, kH, kW),
+    # so the per-channel BN scale folds in along dim 1 (assumes groups=1, as used by ConvTranspose)
+    scale = bn.weight.div(torch.sqrt(bn.eps + bn.running_var))
+    fuseddconv.weight.copy_(deconv.weight * scale.reshape(1, -1, 1, 1))
+
+    # Prepare spatial bias: fold the BN shift and running mean into the deconv bias
+    b_conv = torch.zeros(deconv.out_channels, device=deconv.weight.device) if deconv.bias is None else deconv.bias
+    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+    fuseddconv.bias.copy_(scale * b_conv + b_bn)
+
+    return fuseddconv
+
+
 def model_info(model, verbose=False, imgsz=640):
     # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
     n_p = get_num_params(model)
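
A quick way to sanity-check the new helper is to compare the fused module against the original deconv -> BN pair in eval mode. The sketch below assumes only the import path added by this diff; the layer sizes, batch shapes, and tolerance are illustrative choices, not part of the change:

```python
import torch
import torch.nn as nn

from ultralytics.yolo.utils.torch_utils import fuse_deconv_and_bn

# Build a standalone deconv -> BN pair (sizes are arbitrary for the test)
deconv = nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2, bias=False)
bn = nn.BatchNorm2d(16)

# Populate BN running statistics with a few training-mode passes,
# then switch to eval so fusion folds the frozen statistics
bn.train()
with torch.no_grad():
    for _ in range(10):
        bn(deconv(torch.randn(4, 8, 16, 16)))
deconv.eval()
bn.eval()

fused = fuse_deconv_and_bn(deconv, bn)

x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    y_ref = bn(deconv(x))  # original two-layer path
    y_fused = fused(x)     # single fused layer
assert torch.allclose(y_ref, y_fused, atol=1e-5), 'fusion mismatch'
print('fuse_deconv_and_bn matches deconv -> BN in eval mode')
```

Note that `bn.eval()` matters here: fusion bakes in the frozen running statistics, so a train-mode BN (which normalizes with batch statistics) will not match the fused output.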