Fix model re-fuse() in inference loops (#466)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
@@ -63,7 +63,8 @@ class BaseModel(nn.Module):
 
     def _profile_one_layer(self, m, x, dt):
         """
-        Profile the computation time and FLOPs of a single layer of the model on a given input. Appends the results to the provided list.
+        Profile the computation time and FLOPs of a single layer of the model on a given input.
+        Appends the results to the provided list.
 
         Args:
             m (nn.Module): The layer to be profiled.
@@ -74,10 +75,10 @@ class BaseModel(nn.Module):
             None
         """
         c = m == self.model[-1]  # is final layer, copy input as inplace fix
-        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
+        o = thop.profile(m, inputs=(x.clone() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
         t = time_sync()
         for _ in range(10):
-            m(x.copy() if c else x)
+            m(x.clone() if c else x)
         dt.append((time_sync() - t) * 100)
         if m == self.model[0]:
             LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s}  module")
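Note on the `copy()` → `clone()` change above: `x` is presumably a `torch.Tensor` in this code path, and tensors offer only the in-place `copy_()`, not a `copy()` method (that is the NumPy/list API), so `clone()` is the correct way to take a defensive copy before the in-place final layer runs. A minimal illustration:

import torch

x = torch.randn(1, 3, 64, 64)

y = x.clone()                        # independent copy of the data
assert y.data_ptr() != x.data_ptr()  # separate storage, safe against in-place ops

# torch.Tensor exposes the in-place .copy_() but no .copy(), so the old
# `x.copy()` call would fail on a Tensor input; .copy() is the NumPy/list API.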
@@ -87,20 +88,36 @@ class BaseModel(nn.Module):
 
     def fuse(self):
         """
-        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency.
+        Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the
+        computation efficiency.
 
         Returns:
             (nn.Module): The fused model is returned.
         """
-        LOGGER.info('Fusing layers... ')
-        for m in self.model.modules():
-            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
-                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
-                delattr(m, 'bn')  # remove batchnorm
-                m.forward = m.forward_fuse  # update forward
-        self.info()
+        if not self.is_fused():
+            LOGGER.info('Fusing... ')
+            for m in self.model.modules():
+                if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
+                    m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
+                    delattr(m, 'bn')  # remove batchnorm
+                    m.forward = m.forward_fuse  # update forward
+            self.info()
+
+        return self
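For context on the `fuse_conv_and_bn` call: in eval mode BatchNorm reduces to a per-channel affine map, which can be folded into the preceding convolution's weights and bias. A minimal sketch of the math, not the library's implementation (`fold_bn_into_conv` is a hypothetical helper):

import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a Conv2d equivalent to bn(conv(x)) in eval mode."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    # Eval-mode BN computes y = gamma * (x - mean) / sqrt(var + eps) + beta,
    # a per-output-channel scale and shift that folds into the conv.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

# Quick check: outputs match once BN uses its running statistics (eval mode).
conv, bn = nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fold_bn_into_conv(conv, bn)(x), atol=1e-5)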
+
+    def is_fused(self, thresh=10):
+        """
+        Check if the model has less than a certain threshold of BatchNorm layers.
+
+        Args:
+            thresh (int, optional): The threshold number of BatchNorm layers. Default is 10.
+
+        Returns:
+            bool: True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
+        """
+        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
+        return sum(isinstance(v, bn) for v in self.modules()) < thresh  # True if < 'thresh' BatchNorm layers in model
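The fused-state check added here is a census rather than a flag: fusing deletes each `bn` module, so a fused model reports (near) zero normalization layers. Note that the tuple it builds covers every `nn` class with 'Norm' in its name, not just `BatchNorm2d`. A quick illustration on a toy module (illustrative only):

import torch.nn as nn

norm_types = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
# Matches BatchNorm1d/2d/3d, LayerNorm, GroupNorm, InstanceNorm*, etc.

toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
print(sum(isinstance(m, norm_types) for m in toy.modules()))  # 1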
 
     def info(self, verbose=False, imgsz=640):
         """
         Prints model information
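Taken together, this is what the commit title means by fixing re-fuse in inference loops: `fuse()` can now be called on every iteration and becomes a no-op once the model is already fused, and it returns `self` so callers can keep reassigning the result. A toy reproduction of the pattern (illustrative class, not the library's `BaseModel`; `thresh` is lowered so the single-BN toy actually trips the guard):

import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

    def is_fused(self, thresh=1):  # BaseModel uses thresh=10; 1 suits this toy
        bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)
        return sum(isinstance(v, bn) for v in self.modules()) < thresh

    def fuse(self):
        if not self.is_fused():          # the commit's guard: fuse at most once
            print('Fusing...')
            self.model = self.model[:1]  # stand-in for folding BN into the conv
        return self                      # lets callers write `model = model.fuse()`

m = TinyModel()
for _ in range(3):  # inference loop: 'Fusing...' prints exactly once
    m = m.fuse()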