`ultralytics 8.0.153` YOLO Tasks Cleanup (#4314)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 39395aedc8
commit 822608986c
GPG Key ID: 4AEE18F83AFDEB23

@ -52,6 +52,12 @@ Image classification is a computer vision task that involves categorizing an ima
- [Imagewoof](classify/imagewoof.md): A more challenging subset of ImageNet containing 10 dog breed categories for image classification tasks. - [Imagewoof](classify/imagewoof.md): A more challenging subset of ImageNet containing 10 dog breed categories for image classification tasks.
- [MNIST](classify/mnist.md): A dataset of 70,000 grayscale images of handwritten digits for image classification tasks. - [MNIST](classify/mnist.md): A dataset of 70,000 grayscale images of handwritten digits for image classification tasks.
## [Oriented Bounding Boxes (OBB)](obb/index.md)
Oriented Bounding Boxes (OBB) is a method in computer vision for detecting angled objects in images using rotated bounding boxes, often applied to aerial and satellite imagery.
- [DOTAv2](obb/dota-v2.md): A popular OBB aerial imagery dataset with 1.7 million instances and 11,268 images.
## [Multi-Object Tracking](track/index.md) ## [Multi-Object Tracking](track/index.md)
Multi-object tracking is a computer vision technique that involves detecting and tracking multiple objects over time in a video sequence. Multi-object tracking is a computer vision technique that involves detecting and tracking multiple objects over time in a video sequence.

@ -1,12 +1,12 @@
--- ---
comments: true comments: true
description: Dive deep into various oriented bounding box (OBB) dataset formats compatible with the Ultralytics YOLO model. Grasp the nuances of using and converting datasets to this format. description: Dive deep into various oriented bounding box (OBB) dataset formats compatible with Ultralytics YOLO models. Grasp the nuances of using and converting datasets to this format.
keywords: Ultralytics, YOLO, oriented bounding boxes, OBB, dataset formats, label formats, DOTA v2, data conversion keywords: Ultralytics, YOLO, oriented bounding boxes, OBB, dataset formats, label formats, DOTA v2, data conversion
--- ---
# Oriented Bounding Box Datasets Overview # Oriented Bounding Box (OBB) Datasets Overview
Training a precise object detection model with oriented bounding boxes (OBB) requires a thorough dataset. This guide elucidates the various OBB dataset formats compatible with the Ultralytics YOLO model, offering insights into their structure, application, and methods for format conversions. Training a precise object detection model with oriented bounding boxes (OBB) requires a thorough dataset. This guide explains the various OBB dataset formats compatible with Ultralytics YOLO models, offering insights into their structure, application, and methods for format conversions.
## Supported OBB Dataset Formats ## Supported OBB Dataset Formats

@ -160,7 +160,7 @@ Training settings for YOLO models refer to the various hyperparameters and confi
| `single_cls` | `False` | train multi-class data as single-class | | `single_cls` | `False` | train multi-class data as single-class |
| `rect` | `False` | rectangular training with each batch collated for minimum padding | | `rect` | `False` | rectangular training with each batch collated for minimum padding |
| `cos_lr` | `False` | use cosine learning rate scheduler | | `cos_lr` | `False` | use cosine learning rate scheduler |
| `close_mosaic` | `0` | (int) disable mosaic augmentation for final epochs | | `close_mosaic` | `10` | (int) disable mosaic augmentation for final epochs (0 to disable) |
| `resume` | `False` | resume training from last checkpoint | | `resume` | `False` | resume training from last checkpoint |
| `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] | | `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] |
| `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) | | `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) |

@ -102,7 +102,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `single_cls` | `False` | train multi-class data as single-class | | `single_cls` | `False` | train multi-class data as single-class |
| `rect` | `False` | rectangular training with each batch collated for minimum padding | | `rect` | `False` | rectangular training with each batch collated for minimum padding |
| `cos_lr` | `False` | use cosine learning rate scheduler | | `cos_lr` | `False` | use cosine learning rate scheduler |
| `close_mosaic` | `0` | (int) disable mosaic augmentation for final epochs | | `close_mosaic` | `10` | (int) disable mosaic augmentation for final epochs (0 to disable) |
| `resume` | `False` | resume training from last checkpoint | | `resume` | `False` | resume training from last checkpoint |
| `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] | | `amp` | `True` | Automatic Mixed Precision (AMP) training, choices=[True, False] |
| `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) | | `fraction` | `1.0` | dataset fraction to train on (default is 1.0, all images in train set) |

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.152' __version__ = '8.0.153'
from ultralytics.hub import start from ultralytics.hub import start
from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models import RTDETR, SAM, YOLO

@ -27,7 +27,7 @@ deterministic: True # (bool) whether to enable deterministic mode
single_cls: False # (bool) train multi-class data as single-class single_cls: False # (bool) train multi-class data as single-class
rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val' rect: False # (bool) rectangular training if mode='train' or rectangular validation if mode='val'
cos_lr: False # (bool) use cosine learning rate scheduler cos_lr: False # (bool) use cosine learning rate scheduler
close_mosaic: 10 # (int) disable mosaic augmentation for final epochs close_mosaic: 10 # (int) disable mosaic augmentation for final epochs (0 to disable)
resume: False # (bool) resume training from last checkpoint resume: False # (bool) resume training from last checkpoint
amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check amp: True # (bool) Automatic Mixed Precision (AMP) training, choices=[True, False], True runs AMP check
fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set) fraction: 1.0 # (float) dataset fraction to train on (default is 1.0, all images in train set)
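
Note on `close_mosaic`: the updated comment says the value counts epochs from the end of training and that `0` turns the behaviour off. A minimal sketch of that rule, assuming a hypothetical helper `mosaic_enabled` and zero-based epoch numbering (neither is part of this commit, and the trainer's real check may be written differently):

```python
def mosaic_enabled(epoch: int, epochs: int, close_mosaic: int = 10) -> bool:
    """Return True if mosaic augmentation would still be active at `epoch`.

    Hypothetical helper for illustration only. Assumes epochs are numbered
    from 0 and that close_mosaic=0 means mosaic is never switched off.
    """
    if close_mosaic <= 0:
        return True  # feature disabled: mosaic stays on for the whole run
    return epoch < epochs - close_mosaic  # off for the final `close_mosaic` epochs


# With 100 epochs and the default of 10, mosaic runs for epochs 0..89.
schedule = [mosaic_enabled(e, epochs=100) for e in range(100)]
```

Under these assumptions the docs tables and `default.yaml` now describe the same default: mosaic is dropped for the last 10 epochs unless the user passes `close_mosaic=0`.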

@ -9,7 +9,7 @@ from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter from ultralytics.engine.exporter import Exporter
from ultralytics.hub.utils import HUB_WEB_ROOT from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, emojis,
is_git_dir, yaml_load) is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
@ -448,11 +448,11 @@ class Model:
"""Load model/trainer/validator/predictor.""" """Load model/trainer/validator/predictor."""
try: try:
return self.task_map[self.task][key] return self.task_map[self.task][key]
except Exception: except Exception as e:
name = self.__class__.__name__ name = self.__class__.__name__
mode = inspect.stack()[1][3] # get the function name. mode = inspect.stack()[1][3] # get the function name.
raise NotImplementedError( raise NotImplementedError(
f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.') emojis(f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')) from e
@property @property
def task_map(self): def task_map(self):
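
Note on the `from e` change: chaining keeps the original lookup error attached to the `NotImplementedError` as `__cause__`, so the traceback shows which key was missing. A small sketch of the same pattern, using a hypothetical `TASK_MAP` dict and `lookup_component` function rather than the real `Model` class:

```python
# Hypothetical stand-in for Model.task_map; the names are illustrative only.
TASK_MAP = {'detect': {'trainer': 'DetectionTrainer', 'validator': 'DetectionValidator'}}


def lookup_component(task: str, key: str) -> str:
    """Look up a component and re-raise lookup failures with their cause attached."""
    try:
        return TASK_MAP[task][key]
    except Exception as e:
        # `from e` stores the original exception on __cause__
        raise NotImplementedError(f'`{task}` task does not support `{key}` yet.') from e


try:
    lookup_component('detect', 'predictor')
except NotImplementedError as err:
    cause = err.__cause__  # the original KeyError for the missing 'predictor' key
```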

@ -51,9 +51,18 @@ class BaseValidator:
device (torch.device): Device to use for validation. device (torch.device): Device to use for validation.
batch_i (int): Current batch index. batch_i (int): Current batch index.
training (bool): Whether the model is in training mode. training (bool): Whether the model is in training mode.
speed (float): Batch processing speed in seconds. names (dict): Class names.
jdict (dict): Dictionary to store validation results. seen: Records the number of images seen so far during validation.
stats: Placeholder for statistics during validation.
confusion_matrix: Placeholder for a confusion matrix.
nc: Number of classes.
iouv: (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
jdict (dict): Dictionary to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
""" """
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
@ -65,6 +74,7 @@ class BaseValidator:
save_dir (Path): Directory to save results. save_dir (Path): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress. pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator. args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
""" """
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
@ -74,8 +84,14 @@ class BaseValidator:
self.device = None self.device = None
self.batch_i = None self.batch_i = None
self.training = True self.training = True
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0} self.names = None
self.seen = None
self.stats = None
self.confusion_matrix = None
self.nc = None
self.iouv = None
self.jdict = None self.jdict = None
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f'{self.args.mode}' name = self.args.name or f'{self.args.mode}'
@ -200,26 +216,26 @@ class BaseValidator:
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}") LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
return stats return stats
def match_predictions(self, pred_classes: torch.Tensor, true_classes: torch.Tensor, def match_predictions(self, pred_classes, true_classes, iou):
iou: torch.Tensor) -> torch.Tensor:
""" """
Matches predictions to ground truth objects (pred_classes, true_classes) using IoU. Matches predictions to ground truth objects (pred_classes, true_classes) using IoU.
Args: Args:
pred_classes (torch.Tensor): Predicted class indices of shape(N,). pred_classes (torch.Tensor): Predicted class indices of shape(N,).
true_classes (torch.Tensor): Target class indices of shape(M,). true_classes (torch.Tensor): Target class indices of shape(M,).
iou (torch.Tensor): IoU thresholds from 0.50 to 0.95 in space of 0.05.
Returns: Returns:
(torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds. (torch.Tensor): Correct tensor of shape(N,10) for 10 IoU thresholds.
""" """
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool) correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
correct_class = true_classes[:, None] == pred_classes correct_class = true_classes[:, None] == pred_classes
for i in range(len(self.iouv)): for i, iouv in enumerate(self.iouv):
x = torch.where((iou >= self.iouv[i]) & correct_class) # IoU > threshold and classes match x = torch.nonzero(iou.ge(iouv) & correct_class) # IoU > threshold and classes match
if x[0].shape[0]: if x.shape[0]:
# Concatenate [label, detect, iou] # Concatenate [label, detect, iou]
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() matches = torch.cat((x, iou[x[:, 0], x[:, 1]].unsqueeze(1)), 1).cpu().numpy()
if x[0].shape[0] > 1: if x.shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]] matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
# matches = matches[matches[:, 2].argsort()[::-1]] # matches = matches[matches[:, 2].argsort()[::-1]]
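
Note on `match_predictions`: the matching idea can be followed without PyTorch. The sketch below is a plain-Python simplification, not the library code. It assumes `iou[m][n]` is the IoU between label `m` and detection `n`, and it assumes a second de-duplication pass over labels, which this hunk does not show (the hunk ends after the pass over detections). Tie-breaking may differ from the NumPy version.

```python
def match_predictions_py(pred_classes, true_classes, iou, thresholds):
    """Return correct[n][t]: detection n is a true positive at threshold t."""
    correct = [[False] * len(thresholds) for _ in pred_classes]
    for t, thr in enumerate(thresholds):
        # Candidate (label, detection, iou) triples: classes match and IoU >= threshold
        pairs = [(m, n, iou[m][n])
                 for m, tc in enumerate(true_classes)
                 for n, pc in enumerate(pred_classes)
                 if tc == pc and iou[m][n] >= thr]
        pairs.sort(key=lambda p: p[2], reverse=True)  # highest IoU first

        # Pass 1: each detection keeps only its best-IoU label
        best_for_det = {}
        for m, n, v in pairs:
            best_for_det.setdefault(n, (m, n, v))

        # Pass 2 (assumed): each label keeps only its best-IoU detection
        best_for_label = {}
        for m, n, v in sorted(best_for_det.values(), key=lambda p: p[2], reverse=True):
            best_for_label.setdefault(m, (m, n, v))

        for _m, n, _v in best_for_label.values():
            correct[n][t] = True
    return correct


# Two labels (classes 0 and 1), three detections (classes 0, 0, 1)
iou = [[0.9, 0.6, 0.0],
       [0.0, 0.0, 0.55]]
correct = match_predictions_py([0, 0, 1], [0, 1], iou, thresholds=[0.5, 0.75])
```

Here detection 0 matches label 0 at both thresholds, detection 1 loses label 0 to the higher-IoU detection 0, and detection 2 matches label 1 only at the 0.5 threshold.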

@ -44,7 +44,7 @@ class FastSAMValidator(DetectionValidator):
'R', 'mAP50', 'mAP50-95)') 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds): def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto.""" """Post-processes YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0], p = ops.non_max_suppression(preds[0],
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,

@ -11,7 +11,7 @@ from ultralytics.utils.ops import xyxy2xywh
class NASPredictor(BasePredictor): class NASPredictor(BasePredictor):
def postprocess(self, preds_in, img, orig_imgs): def postprocess(self, preds_in, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects.""" """Postprocess predictions and returns a list of Results objects."""
# Cat boxes and class scores # Cat boxes and class scores
boxes = xyxy2xywh(preds_in[0][0]) boxes = xyxy2xywh(preds_in[0][0])

@ -310,7 +310,7 @@ class Predictor(BasePredictor):
self.done_warmup = True self.done_warmup = True
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses inference output predictions to create detection masks for objects.""" """Post-processes inference output predictions to create detection masks for objects."""
# (N, 1, H, W), (N, 1) # (N, 1, H, W), (N, 1)
pred_masks, pred_scores = preds[:2] pred_masks, pred_scores = preds[:2]
pred_bboxes = preds[2] if self.segment_all else None pred_bboxes = preds[2] if self.segment_all else None

@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions to return Results objects.""" """Post-processes predictions to return Results objects."""
results = [] results = []
for i, pred in enumerate(preds): for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs

@ -43,11 +43,7 @@ class ClassificationTrainer(BaseTrainer):
return model return model
def setup_model(self): def setup_model(self):
""" """load/create/download model for any task"""
load/create/download model for any task
"""
# Classification models require special handling
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return return
@ -65,7 +61,7 @@ class ClassificationTrainer(BaseTrainer):
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.') FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc']) ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume return # do not return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None): def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train') return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@ -102,9 +98,9 @@ class ClassificationTrainer(BaseTrainer):
def label_loss_items(self, loss_items=None, prefix='train'): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
""" """
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names] keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None: if loss_items is None:
return keys return keys
@ -144,7 +140,7 @@ class ClassificationTrainer(BaseTrainer):
def train(cfg=DEFAULT_CFG, use_python=False): def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model.""" """Train a YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18" model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist") data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else '' device = cfg.device if cfg.device is not None else ''

@ -14,6 +14,8 @@ class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar.""" """Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = 'classify' self.args.task = 'classify'
self.metrics = ClassifyMetrics() self.metrics = ClassifyMetrics()

@ -10,7 +10,7 @@ from ultralytics.utils import DEFAULT_CFG, ROOT, ops
class DetectionPredictor(BasePredictor): class DetectionPredictor(BasePredictor):
def postprocess(self, preds, img, orig_imgs): def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions and returns a list of Results objects.""" """Post-processes predictions and returns a list of Results objects."""
preds = ops.non_max_suppression(preds, preds = ops.non_max_suppression(preds,
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,

@ -13,7 +13,6 @@ from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
# BaseTrainer python usage
class DetectionTrainer(BaseTrainer): class DetectionTrainer(BaseTrainer):
def build_dataset(self, img_path, mode='train', batch=None): def build_dataset(self, img_path, mode='train', batch=None):
@ -69,9 +68,9 @@ class DetectionTrainer(BaseTrainer):
def label_loss_items(self, loss_items=None, prefix='train'): def label_loss_items(self, loss_items=None, prefix='train'):
""" """
Returns a loss dict with labelled training loss items tensor Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
""" """
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names] keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is not None: if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats

@ -20,9 +20,10 @@ class DetectionValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize detection model with necessary variables and settings.""" """Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'detect' self.nt_per_class = None
self.is_coco = False self.is_coco = False
self.class_map = None self.class_map = None
self.args.task = 'detect'
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95 self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel() self.niou = self.iouv.numel()
@ -155,18 +156,23 @@ class DetectionValidator(BaseValidator):
def _process_batch(self, detections, labels): def _process_batch(self, detections, labels):
""" """
Return correct prediction matrix Return correct prediction matrix.
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class Args:
labels (array[M, 5]), class, x1, y1, x2, y2 detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
Returns: Returns:
correct (array[N, 10]), for 10 IoU levels (torch.Tensor): Correct prediction matrix of shape [N, 10] for 10 IoU levels.
""" """
iou = box_iou(labels[:, 1:], detections[:, :4]) iou = box_iou(labels[:, 1:], detections[:, :4])
return self.match_predictions(detections[:, 5], labels[:, 0], iou) return self.match_predictions(detections[:, 5], labels[:, 0], iou)
def build_dataset(self, img_path, mode='val', batch=None): def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset """
Build YOLO Dataset.
Args: Args:
img_path (str): Path to the folder containing images. img_path (str): Path to the folder containing images.
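
Note on `_process_batch`: the revised docstring fixes the column layout of both inputs, which is what the slices `labels[:, 1:]` and `detections[:, :4]` rely on. A plain-Python sketch of the IoU step under that layout, using a hypothetical `box_iou_py` helper on nested lists in place of the tensor-based `box_iou`:

```python
def box_iou_py(boxes1, boxes2):
    """Pairwise IoU for xyxy boxes; result[m][n] compares boxes1[m] with boxes2[n]."""
    def area(b):
        return max(0.0, b[2] - b[0]) * max(0.0, b[3] - b[1])

    result = []
    for a in boxes1:
        row = []
        for b in boxes2:
            iw = min(a[2], b[2]) - max(a[0], b[0])  # overlap width (may be negative)
            ih = min(a[3], b[3]) - max(a[1], b[1])  # overlap height (may be negative)
            inter = max(0.0, iw) * max(0.0, ih)
            union = area(a) + area(b) - inter
            row.append(inter / union if union > 0 else 0.0)
        result.append(row)
    return result


labels = [[0, 0, 0, 2, 2]]                 # class, x1, y1, x2, y2
detections = [[0, 0, 2, 2, 0.9, 0],        # x1, y1, x2, y2, conf, class
              [1, 0, 3, 2, 0.8, 0],
              [5, 5, 6, 6, 0.7, 1]]
iou = box_iou_py([l[1:] for l in labels], [d[:4] for d in detections])
```

The result has one row per label and one column per detection, the same orientation that `match_predictions` receives.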

@ -8,7 +8,6 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results from ultralytics.utils.plotting import plot_images, plot_results
# BaseTrainer python usage
class PoseTrainer(yolo.detect.DetectionTrainer): class PoseTrainer(yolo.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

@ -17,6 +17,8 @@ class PoseValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes.""" """Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = 'pose' self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps': if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
@ -112,14 +114,19 @@ class PoseValidator(DetectionValidator):
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None): def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
""" """
Return correct prediction matrix Return correct prediction matrix.
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class Args:
labels (array[M, 5]), class, x1, y1, x2, y2 detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
pred_kpts (array[N, 51]), 51 = 17 * 3 Each detection is of the format: x1, y1, x2, y2, conf, class.
gt_kpts (array[N, 51]) labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
51 corresponds to 17 keypoints each with 3 values.
gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
Returns: Returns:
correct (array[N, 10]), for 10 IoU levels torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
""" """
if pred_kpts is not None and gt_kpts is not None: if pred_kpts is not None and gt_kpts is not None:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384 # `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384

@ -8,7 +8,6 @@ from ultralytics.utils import DEFAULT_CFG, RANK
from ultralytics.utils.plotting import plot_images, plot_results from ultralytics.utils.plotting import plot_images, plot_results
# BaseTrainer python usage
class SegmentationTrainer(yolo.detect.DetectionTrainer): class SegmentationTrainer(yolo.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

@ -19,6 +19,8 @@ class SegmentationValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None): def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.""" """Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks) super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.plot_masks = None
self.process = None
self.args.task = 'segment' self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot) self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
@ -44,7 +46,7 @@ class SegmentationValidator(DetectionValidator):
'R', 'mAP50', 'mAP50-95)') 'R', 'mAP50', 'mAP50-95)')
def postprocess(self, preds): def postprocess(self, preds):
"""Postprocesses YOLO predictions and returns output detections with proto.""" """Post-processes YOLO predictions and returns output detections with proto."""
p = ops.non_max_suppression(preds[0], p = ops.non_max_suppression(preds[0],
self.args.conf, self.args.conf,
self.args.iou, self.args.iou,

@ -1,4 +1,4 @@
# Ultralytics YOLO 🚀, GPL-3.0 license # Ultralytics YOLO 🚀, AGPL-3.0 license
import os import os
