ultralytics 8.0.153
YOLO Tasks Cleanup (#4314)
@@ -44,7 +44,7 @@ class FastSAMValidator(DetectionValidator):
                                          'R', 'mAP50', 'mAP50-95)')

     def postprocess(self, preds):
-        """Postprocesses YOLO predictions and returns output detections with proto."""
+        """Post-processes YOLO predictions and returns output detections with proto."""
         p = ops.non_max_suppression(preds[0],
                                     self.args.conf,
                                     self.args.iou,
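The hunk above only rewords the docstring; the surrounding ops.non_max_suppression call is untouched. As context, a minimal, hedged sketch of how such a call behaves (the dummy tensor shape and thresholds below are assumptions for illustration, not part of the commit):

# Hedged sketch: dummy (batch, 4 box + 80 class, anchors) head output run through NMS.
import torch
from ultralytics.utils import ops

raw = torch.randn(1, 84, 8400)  # assumed raw output shape for an 80-class model
dets = ops.non_max_suppression(raw, conf_thres=0.25, iou_thres=0.45, max_det=300)
print(dets[0].shape)  # one (n, 6) tensor per image: x1, y1, x2, y2, conf, cls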
@@ -11,7 +11,7 @@ from ultralytics.utils.ops import xyxy2xywh
 class NASPredictor(BasePredictor):

     def postprocess(self, preds_in, img, orig_imgs):
-        """Postprocesses predictions and returns a list of Results objects."""
+        """Postprocess predictions and returns a list of Results objects."""

         # Cat boxes and class scores
         boxes = xyxy2xywh(preds_in[0][0])
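For context on the unchanged line above, xyxy2xywh converts corner-format boxes to center format before the NAS boxes are concatenated with their class scores. A small illustrative example (values are made up):

import torch
from ultralytics.utils.ops import xyxy2xywh

boxes_xyxy = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # x1, y1, x2, y2
print(xyxy2xywh(boxes_xyxy))  # tensor([[30., 50., 40., 60.]]) -> cx, cy, w, h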
@@ -310,7 +310,7 @@ class Predictor(BasePredictor):
             self.done_warmup = True

     def postprocess(self, preds, img, orig_imgs):
-        """Postprocesses inference output predictions to create detection masks for objects."""
+        """Post-processes inference output predictions to create detection masks for objects."""
         # (N, 1, H, W), (N, 1)
         pred_masks, pred_scores = preds[:2]
         pred_bboxes = preds[2] if self.segment_all else None
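This postprocess appears to belong to the SAM Predictor, which turns mask and score tensors into Results. A hedged usage sketch of that predictor through the high-level API (the checkpoint name, image path, and point prompt are assumptions):

from ultralytics import SAM

model = SAM('sam_b.pt')                                  # assumed checkpoint name
results = model('bus.jpg', points=[900, 370], labels=[1])  # assumed image and point prompt
print(results[0].masks.data.shape)                       # (N, H, W) predicted masks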
@@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
         return img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32

     def postprocess(self, preds, img, orig_imgs):
-        """Postprocesses predictions to return Results objects."""
+        """Post-processes predictions to return Results objects."""
         results = []
         for i, pred in enumerate(preds):
             orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
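For context, a hedged sketch of how this classification predictor is reached from the public API and what the returned Results carry (the weight file and image path are assumptions):

from ultralytics import YOLO

results = YOLO('yolov8n-cls.pt')('bus.jpg')  # assumed local image path
probs = results[0].probs                     # classification probabilities
print(probs.top1, float(probs.top1conf))     # top-1 class index and its confidence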
@@ -43,11 +43,7 @@ class ClassificationTrainer(BaseTrainer):
         return model

     def setup_model(self):
-        """
-        load/create/download model for any task
-        """
-        # Classification models require special handling
-
+        """load/create/download model for any task"""
         if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return

@@ -65,7 +61,7 @@ class ClassificationTrainer(BaseTrainer):
                 FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
         ClassificationModel.reshape_outputs(self.model, self.data['nc'])

-        return  # dont return ckpt. Classification doesn't support resume
+        return  # do not return ckpt. Classification doesn't support resume

     def build_dataset(self, img_path, mode='train', batch=None):
         return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@@ -102,9 +98,9 @@ class ClassificationTrainer(BaseTrainer):

     def label_loss_items(self, loss_items=None, prefix='train'):
         """
-        Returns a loss dict with labelled training loss items tensor
+        Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
+        segmentation & detection
         """
-        # Not needed for classification but necessary for segmentation & detection
         keys = [f'{prefix}/{x}' for x in self.loss_names]
         if loss_items is None:
             return keys
@@ -144,7 +140,7 @@ class ClassificationTrainer(BaseTrainer):


 def train(cfg=DEFAULT_CFG, use_python=False):
-    """Train the YOLO classification model."""
+    """Train a YOLO classification model."""
     model = cfg.model or 'yolov8n-cls.pt'  # or "resnet18"
     data = cfg.data or 'mnist160'  # or yolo.ClassificationDataset("mnist")
     device = cfg.device if cfg.device is not None else ''
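The module-level train() helper above mirrors the public API. A minimal, hedged equivalent using the YOLO class, with the model and dataset names taken from the defaults shown above (epoch count and image size are assumptions):

from ultralytics import YOLO

model = YOLO('yolov8n-cls.pt')
model.train(data='mnist160', epochs=1, imgsz=64)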
@@ -14,6 +14,8 @@ class ClassificationValidator(BaseValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.targets = None
+        self.pred = None
         self.args.task = 'classify'
         self.metrics = ClassifyMetrics()

@@ -10,7 +10,7 @@ from ultralytics.utils import DEFAULT_CFG, ROOT, ops
 class DetectionPredictor(BasePredictor):

     def postprocess(self, preds, img, orig_imgs):
-        """Postprocesses predictions and returns a list of Results objects."""
+        """Post-processes predictions and returns a list of Results objects."""
         preds = ops.non_max_suppression(preds,
                                         self.args.conf,
                                         self.args.iou,
@@ -13,7 +13,6 @@ from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
 from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first


-# BaseTrainer python usage
 class DetectionTrainer(BaseTrainer):

     def build_dataset(self, img_path, mode='train', batch=None):
@@ -69,9 +68,9 @@ class DetectionTrainer(BaseTrainer):

     def label_loss_items(self, loss_items=None, prefix='train'):
         """
-        Returns a loss dict with labelled training loss items tensor
+        Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
+        segmentation & detection
         """
-        # Not needed for classification but necessary for segmentation & detection
         keys = [f'{prefix}/{x}' for x in self.loss_names]
         if loss_items is not None:
             loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
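The reworded docstring describes behaviour that is easiest to see with concrete values. A hedged sketch of the labelling step (the loss names and numbers below are made up for illustration; this is not the trainer code itself):

loss_names = ('box_loss', 'cls_loss', 'dfl_loss')  # assumed detection loss names
prefix = 'train'
keys = [f'{prefix}/{x}' for x in loss_names]
loss_items = [round(float(x), 5) for x in (1.23456789, 0.9876, 1.5)]
print(dict(zip(keys, loss_items)))  # {'train/box_loss': 1.23457, 'train/cls_loss': 0.9876, 'train/dfl_loss': 1.5}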
@@ -20,9 +20,10 @@ class DetectionValidator(BaseValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize detection model with necessary variables and settings."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
-        self.args.task = 'detect'
+        self.nt_per_class = None
         self.is_coco = False
         self.class_map = None
+        self.args.task = 'detect'
         self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
         self.iouv = torch.linspace(0.5, 0.95, 10)  # iou vector for mAP@0.5:0.95
         self.niou = self.iouv.numel()
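For context on the unchanged tail of this hunk, the IoU vector drives the mAP50-95 computation. A tiny hedged illustration:

import torch

iouv = torch.linspace(0.5, 0.95, 10)  # tensor([0.5000, 0.5500, ..., 0.9500])
print(iouv.numel())                   # 10 -> one column per threshold in the correct-prediction matrix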
@@ -155,18 +156,23 @@ class DetectionValidator(BaseValidator):

     def _process_batch(self, detections, labels):
         """
-        Return correct prediction matrix
-        Arguments:
-            detections (array[N, 6]), x1, y1, x2, y2, conf, class
-            labels (array[M, 5]), class, x1, y1, x2, y2
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
+                Each detection is of the format: x1, y1, x2, y2, conf, class.
+            labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
+                Each label is of the format: class, x1, y1, x2, y2.
+
         Returns:
-            correct (array[N, 10]), for 10 IoU levels
+            (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
         """
         iou = box_iou(labels[:, 1:], detections[:, :4])
         return self.match_predictions(detections[:, 5], labels[:, 0], iou)

     def build_dataset(self, img_path, mode='val', batch=None):
-        """Build YOLO Dataset
+        """
+        Build YOLO Dataset.

         Args:
             img_path (str): Path to the folder containing images.
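The rewritten docstring above describes an [N, 10] correct-prediction matrix. Below is a hedged, self-contained sketch of one way such a matrix can be built from an IoU matrix; it is a simplified approximation for illustration, not the library's match_predictions implementation:

import torch

def correct_matrix(iou, det_cls, gt_cls, iouv=torch.linspace(0.5, 0.95, 10)):
    """iou: [M, N] IoU of M ground-truth boxes vs N detections; returns a [N, 10] bool matrix."""
    iou = iou * (gt_cls[:, None] == det_cls[None, :])  # zero out class mismatches
    best_iou = iou.amax(dim=0)                         # best ground-truth overlap per detection
    return best_iou[:, None] >= iouv[None, :]          # one column per IoU threshold

correct = correct_matrix(torch.rand(3, 5), torch.zeros(5), torch.zeros(3))
print(correct.shape)  # torch.Size([5, 10])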
@@ -8,7 +8,6 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
 from ultralytics.utils.plotting import plot_images, plot_results


-# BaseTrainer python usage
 class PoseTrainer(yolo.detect.DetectionTrainer):

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -17,6 +17,8 @@ class PoseValidator(DetectionValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.sigma = None
+        self.kpt_shape = None
         self.args.task = 'pose'
         self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
         if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
@@ -112,14 +114,19 @@ class PoseValidator(DetectionValidator):

     def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
         """
-        Return correct prediction matrix
-        Arguments:
-            detections (array[N, 6]), x1, y1, x2, y2, conf, class
-            labels (array[M, 5]), class, x1, y1, x2, y2
-            pred_kpts (array[N, 51]), 51 = 17 * 3
-            gt_kpts (array[N, 51])
+        Return correct prediction matrix.
+
+        Args:
+            detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
+                Each detection is of the format: x1, y1, x2, y2, conf, class.
+            labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
+                Each label is of the format: class, x1, y1, x2, y2.
+            pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
+                51 corresponds to 17 keypoints each with 3 values.
+            gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
+
         Returns:
-            correct (array[N, 10]), for 10 IoU levels
+            torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
         """
         if pred_kpts is not None and gt_kpts is not None:
             # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
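The comment above points at COCO keypoint evaluation, where a 0.53 factor rescales box area before computing object keypoint similarity (OKS). Below is a hedged, self-contained sketch of the OKS formula for orientation; the shapes, the uniform sigma, and the usage values are assumptions, and this is not the library's kpt_iou:

import torch

def oks(gt_kpts, pred_kpts, area, sigma, eps=1e-7):
    """gt_kpts, pred_kpts: [K, 3] as (x, y, visibility); area: scalar object area; sigma: [K] per-keypoint."""
    d2 = (gt_kpts[:, 0] - pred_kpts[:, 0]) ** 2 + (gt_kpts[:, 1] - pred_kpts[:, 1]) ** 2
    e = d2 / (2 * sigma) ** 2 / (area + eps) / 2   # COCO-style exponent
    vis = gt_kpts[:, 2] > 0                        # score only labelled/visible keypoints
    return torch.exp(-e)[vis].mean() if vis.any() else torch.tensor(0.0)

kpts = torch.rand(17, 3)
print(oks(kpts, kpts, area=torch.tensor(100.0) * 0.53, sigma=torch.full((17,), 0.05)))  # identical keypoints -> 1.0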
@@ -8,7 +8,6 @@ from ultralytics.utils import DEFAULT_CFG, RANK
 from ultralytics.utils.plotting import plot_images, plot_results


-# BaseTrainer python usage
 class SegmentationTrainer(yolo.detect.DetectionTrainer):

     def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
@@ -19,6 +19,8 @@ class SegmentationValidator(DetectionValidator):
     def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
         """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
         super().__init__(dataloader, save_dir, pbar, args, _callbacks)
+        self.plot_masks = None
+        self.process = None
         self.args.task = 'segment'
         self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)

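For context, the segmentation validator initialized above is what runs under a high-level val() call. A hedged usage sketch (the weights and dataset names are assumptions):

from ultralytics import YOLO

metrics = YOLO('yolov8n-seg.pt').val(data='coco8-seg.yaml')
print(metrics.box.map50, metrics.seg.map50)  # box and mask mAP50 from SegmentMetrics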
@@ -44,7 +46,7 @@ class SegmentationValidator(DetectionValidator):
                                          'R', 'mAP50', 'mAP50-95)')

     def postprocess(self, preds):
-        """Postprocesses YOLO predictions and returns output detections with proto."""
+        """Post-processes YOLO predictions and returns output detections with proto."""
         p = ops.non_max_suppression(preds[0],
                                     self.args.conf,
                                     self.args.iou,