ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)

This commit is contained in:
Glenn Jocher
2023-08-12 02:30:57 +02:00
committed by GitHub
parent 39395aedc8
commit 822608986c
22 changed files with 87 additions and 55 deletions

View File

@ -10,7 +10,7 @@ from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class DetectionPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects."""
"""Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds,
self.args.conf,
self.args.iou,

View File

@ -13,7 +13,6 @@ from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
# BaseTrainer python usage
class DetectionTrainer(BaseTrainer):
def build_dataset(self, img_path, mode='train', batch=None):
@ -69,9 +68,9 @@ class DetectionTrainer(BaseTrainer):
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats

View File

@ -20,9 +20,10 @@ class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'detect'
self.nt_per_class = None
self.is_coco = False
self.class_map = None
self.args.task = 'detect'
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
@ -155,18 +156,23 @@ class DetectionValidator(BaseValidator):
def _process_batch(self, detections, labels):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
Returns:
correct (array[N, 10]), for 10 IoU levels
(torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou)
def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.