|
|
|
@ -54,6 +54,22 @@ class BaseModel(nn.Module):
|
|
|
|
|
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
|
|
|
|
augment (bool): Augment image during prediction, defaults to False.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(torch.Tensor): The last output of the model.
|
|
|
|
|
"""
|
|
|
|
|
if augment:
|
|
|
|
|
return self._predict_augment(x)
|
|
|
|
|
return self._predict_once(x, profile, visualize)
|
|
|
|
|
|
|
|
|
|
def _predict_once(self, x, profile=False, visualize=False):
|
|
|
|
|
"""
|
|
|
|
|
Perform a forward pass through the network.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
x (torch.Tensor): The input tensor to the model.
|
|
|
|
|
profile (bool): Print the computation time of each layer if True, defaults to False.
|
|
|
|
|
visualize (bool): Save the feature maps of the model if True, defaults to False.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(torch.Tensor): The last output of the model.
|
|
|
|
|
"""
|
|
|
|
@ -69,6 +85,13 @@ class BaseModel(nn.Module):
|
|
|
|
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def _predict_augment(self, x):
|
|
|
|
|
"""Perform augmentations on input image x and return augmented inference."""
|
|
|
|
|
LOGGER.warning(
|
|
|
|
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
|
|
|
|
)
|
|
|
|
|
return self._predict_once(x)
|
|
|
|
|
|
|
|
|
|
def _profile_one_layer(self, m, x, dt):
|
|
|
|
|
"""
|
|
|
|
|
Profile the computation time and FLOPs of a single layer of the model on a given input.
|
|
|
|
@ -225,13 +248,7 @@ class DetectionModel(BaseModel):
|
|
|
|
|
self.info()
|
|
|
|
|
LOGGER.info('')
|
|
|
|
|
|
|
|
|
|
def predict(self, x, augment=False, profile=False, visualize=False):
|
|
|
|
|
"""Run forward pass on input image(s) with optional augmentation and profiling."""
|
|
|
|
|
if augment:
|
|
|
|
|
return self._forward_augment(x) # augmented inference, None
|
|
|
|
|
return super().predict(x, profile=profile, visualize=visualize) # single-scale inference, train
|
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
|
|
|
|
|
def _predict_augment(self, x):
|
|
|
|
|
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
|
|
|
|
img_size = x.shape[-2:] # height, width
|
|
|
|
|
s = [1, 0.83, 0.67] # scales
|
|
|
|
@ -279,13 +296,16 @@ class SegmentationModel(DetectionModel):
|
|
|
|
|
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
|
|
|
|
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
|
|
|
|
|
"""Undocumented function."""
|
|
|
|
|
raise NotImplementedError(emojis('WARNING ⚠️ SegmentationModel has not supported augment inference yet!'))
|
|
|
|
|
|
|
|
|
|
def init_criterion(self):
|
|
|
|
|
return v8SegmentationLoss(self)
|
|
|
|
|
|
|
|
|
|
def _predict_augment(self, x):
|
|
|
|
|
"""Perform augmentations on input image x and return augmented inference."""
|
|
|
|
|
LOGGER.warning(
|
|
|
|
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
|
|
|
|
)
|
|
|
|
|
return self._predict_once(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PoseModel(DetectionModel):
|
|
|
|
|
"""YOLOv8 pose model."""
|
|
|
|
@ -302,9 +322,12 @@ class PoseModel(DetectionModel):
|
|
|
|
|
def init_criterion(self):
|
|
|
|
|
return v8PoseLoss(self)
|
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
|
|
|
|
|
"""Undocumented function."""
|
|
|
|
|
raise NotImplementedError(emojis('WARNING ⚠️ PoseModel has not supported augment inference yet!'))
|
|
|
|
|
def _predict_augment(self, x):
|
|
|
|
|
"""Perform augmentations on input image x and return augmented inference."""
|
|
|
|
|
LOGGER.warning(
|
|
|
|
|
f'WARNING ⚠️ {self.__class__.__name__} has not supported augment inference yet! Now using single-scale inference instead.'
|
|
|
|
|
)
|
|
|
|
|
return self._predict_once(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ClassificationModel(BaseModel):
|
|
|
|
@ -448,10 +471,6 @@ class RTDETRDetectionModel(DetectionModel):
|
|
|
|
|
x = head([y[j] for j in head.f], batch) # head inference
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def _forward_augment(self, x):
|
|
|
|
|
"""Undocumented function."""
|
|
|
|
|
raise NotImplementedError(emojis('WARNING ⚠️ RTDETRModel has not supported augment inference yet!'))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Ensemble(nn.ModuleList):
|
|
|
|
|
"""Ensemble of models."""
|
|
|
|
|