Support fuse-deconv-and-bn (#786)

This commit is contained in:
tripleMu
2023-02-04 03:50:25 +08:00
committed by GitHub
parent fa8811dcee
commit 5a80ad98db
3 changed files with 33 additions and 2 deletions

View File

@ -62,6 +62,9 @@ class ConvTranspose(nn.Module):
def forward(self, x):
return self.act(self.bn(self.conv_transpose(x)))
def forward_fuse(self, x):
return self.act(self.conv_transpose(x))
class DFL(nn.Module):
# Integral module of Distribution Focal Loss (DFL) proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

View File

@ -12,8 +12,8 @@ from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, Bot
GhostBottleneck, GhostConv, Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_yaml
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, initialize_weights, intersect_dicts, make_divisible,
model_info, scale_img, time_sync)
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
class BaseModel(nn.Module):
@ -100,6 +100,10 @@ class BaseModel(nn.Module):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
self.info()
return self