From c050b2d1a87f6cb43516243daec1cb3f95f0fb98 Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Thu, 1 Jun 2023 06:41:10 +0800 Subject: [PATCH] Task augment (#2924) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/nn/tasks.py | 55 +++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index cfbc122..aaeef5b 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -54,6 +54,22 @@ class BaseModel(nn.Module): visualize (bool): Save the feature maps of the model if True, defaults to False. augment (bool): Augment image during prediction, defaults to False. + Returns: + (torch.Tensor): The last output of the model. + """ + if augment: + return self._predict_augment(x) + return self._predict_once(x, profile, visualize) + + def _predict_once(self, x, profile=False, visualize=False): + """ + Perform a forward pass through the network. + + Args: + x (torch.Tensor): The input tensor to the model. + profile (bool): Print the computation time of each layer if True, defaults to False. + visualize (bool): Save the feature maps of the model if True, defaults to False. + Returns: (torch.Tensor): The last output of the model. """ @@ -69,6 +85,13 @@ class BaseModel(nn.Module): feature_visualization(x, m.type, m.i, save_dir=visualize) return x + def _predict_augment(self, x): + """Perform augmentations on input image x and return augmented inference.""" + LOGGER.warning( + f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.' + ) + return self._predict_once(x) + def _profile_one_layer(self, m, x, dt): """ Profile the computation time and FLOPs of a single layer of the model on a given input. @@ -225,13 +248,7 @@ class DetectionModel(BaseModel): self.info() LOGGER.info('') - def predict(self, x, augment=False, profile=False, visualize=False): - """Run forward pass on input image(s) with optional augmentation and profiling.""" - if augment: - return self._forward_augment(x) # augmented inference, None - return super().predict(x, profile=profile, visualize=visualize) # single-scale inference, train - - def _forward_augment(self, x): + def _predict_augment(self, x): """Perform augmentations on input image x and return augmented inference and train outputs.""" img_size = x.shape[-2:] # height, width s = [1, 0.83, 0.67] # scales @@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel): """Initialize YOLOv8 segmentation model with given config and parameters.""" super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose) - def _forward_augment(self, x): - """Undocumented function.""" - raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!')) - def init_criterion(self): return v8SegmentationLoss(self) + def _predict_augment(self, x): + """Perform augmentations on input image x and return augmented inference.""" + LOGGER.warning( + f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.' + ) + return self._predict_once(x) + class PoseModel(DetectionModel): """YOLOv8 pose model.""" @@ -302,9 +322,12 @@ class PoseModel(DetectionModel): def init_criterion(self): return v8PoseLoss(self) - def _forward_augment(self, x): - """Undocumented function.""" - raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!')) + def _predict_augment(self, x): + """Perform augmentations on input image x and return augmented inference.""" + LOGGER.warning( + f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.' + ) + return self._predict_once(x) class ClassificationModel(BaseModel): @@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel): x = head([y[j] for j in head.f], batch) # head inference return x - def _forward_augment(self, x): - """Undocumented function.""" - raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!')) - class Ensemble(nn.ModuleList): """Ensemble of models."""