Task augment (#2924)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Laughing 2 years ago committed by GitHub
parent f4b34fc30b
commit c050b2d1a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,6 +54,22 @@ class BaseModel(nn.Module):
visualize (bool): Save the feature maps of the model if True, defaults to False.
augment (bool): Augment image during prediction, defaults to False.
Returns:
(torch.Tensor): The last output of the model.
"""
if augment:
return self._predict_augment(x)
return self._predict_once(x, profile, visualize)
def _predict_once(self, x, profile=False, visualize=False):
"""
Perform a forward pass through the network.
Args:
x (torch.Tensor): The input tensor to the model.
profile (bool): Print the computation time of each layer if True, defaults to False.
visualize (bool): Save the feature maps of the model if True, defaults to False.
Returns:
(torch.Tensor): The last output of the model.
"""
@ -69,6 +85,13 @@ class BaseModel(nn.Module):
feature_visualization(x, m.type, m.i, save_dir=visualize)
return x
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
LOGGER.warning(
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
)
return self._predict_once(x)
def _profile_one_layer(self, m, x, dt):
"""
Profile the computation time and FLOPs of a single layer of the model on a given input.
@ -225,13 +248,7 @@ class DetectionModel(BaseModel):
self.info()
LOGGER.info('')
def predict(self, x, augment=False, profile=False, visualize=False):
"""Run forward pass on input image(s) with optional augmentation and profiling."""
if augment:
return self._forward_augment(x) # augmented inference, None
return super().predict(x, profile=profile, visualize=visualize) # single-scale inference, train
def _forward_augment(self, x):
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference and train outputs."""
img_size = x.shape[-2:] # height, width
s = [1, 0.83, 0.67] # scales
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel):
"""Initialize YOLOv8 segmentation model with given config and parameters."""
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
def _forward_augment(self, x):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
def init_criterion(self):
return v8SegmentationLoss(self)
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
LOGGER.warning(
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
)
return self._predict_once(x)
class PoseModel(DetectionModel):
"""YOLOv8 pose model."""
@ -302,9 +322,12 @@ class PoseModel(DetectionModel):
def init_criterion(self):
return v8PoseLoss(self)
def _forward_augment(self, x):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
def _predict_augment(self, x):
"""Perform augmentations on input image x and return augmented inference."""
LOGGER.warning(
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
)
return self._predict_once(x)
class ClassificationModel(BaseModel):
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel):
x = head([y[j] for j in head.f], batch) # head inference
return x
def _forward_augment(self, x):
"""Undocumented function."""
raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
class Ensemble(nn.ModuleList):
"""Ensemble of models."""

Loading…
Cancel
Save