Fix model re-fuse() in inference loops (#466)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher
2023-01-18 20:32:36 +01:00
committed by GitHub
parent cc3c774bde
commit a86218b767
22 changed files with 135 additions and 66 deletions

View File

@ -63,7 +63,8 @@ class BaseModel(nn.Module):
def _profile_one_layer(self, m, x, dt):
"""
Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list.
Profile the computation time and FLOPs of a single layer of the model on a given input.
Appends the results to the provided list.
Args:
m (nn.Module): The layer to be profiled.
@ -74,10 +75,10 @@ class BaseModel(nn.Module):
None
"""
c = m == self.model[-1] # is final layer, copy input as inplace fix
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
o = thop.profile(m, inputs=(x.clone() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
t = time_sync()
for _ in range(10):
m(x.copy() if c else x)
m(x.clone() if c else x)
dt.append((time_sync() - t) * 100)
if m == self.model[0]:
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
@ -87,20 +88,36 @@ class BaseModel(nn.Module):
def fuse(self):
"""
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency.
Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
computation efficiency.
Returns:
(nn.Module): The fused model is returned.
"""
LOGGER.info('Fusing layers... ')
for m in self.model.modules():
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
self.info()
if not self.is_fused():
LOGGER.info('Fusing... ')
for m in self.model.modules():
if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, 'bn') # remove batchnorm
m.forward = m.forward_fuse # update forward
self.info()
return self
def is_fused(self, thresh=10):
"""
Check if the model has less than a certain threshold of BatchNorm layers.
Args:
thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
Returns:
bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
"""
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, verbose=False, imgsz=640):
"""
Prints model information