ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)

This commit is contained in:
Glenn Jocher
2023-08-12 02:30:57 +02:00
committed by GitHub
parent 39395aedc8
commit 822608986c
22 changed files with 87 additions and 55 deletions

View File

@ -8,7 +8,6 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.plotting import plot_images, plot_results
# BaseTrainer python usage
class PoseTrainer(yolo.detect.DetectionTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):

View File

@ -17,6 +17,8 @@ class PoseValidator(DetectionValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
if isinstance(self.args.device, str) and self.args.device.lower() == 'mps':
@ -112,14 +114,19 @@ class PoseValidator(DetectionValidator):
def _process_batch(self, detections, labels, pred_kpts=None, gt_kpts=None):
"""
Return correct prediction matrix
Arguments:
detections (array[N, 6]), x1, y1, x2, y2, conf, class
labels (array[M, 5]), class, x1, y1, x2, y2
pred_kpts (array[N, 51]), 51 = 17 * 3
gt_kpts (array[N, 51])
Return correct prediction matrix.
Args:
detections (torch.Tensor): Tensor of shape [N, 6] representing detections.
Each detection is of the format: x1, y1, x2, y2, conf, class.
labels (torch.Tensor): Tensor of shape [M, 5] representing labels.
Each label is of the format: class, x1, y1, x2, y2.
pred_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing predicted keypoints.
51 corresponds to 17 keypoints each with 3 values.
gt_kpts (torch.Tensor, optional): Tensor of shape [N, 51] representing ground truth keypoints.
Returns:
correct (array[N, 10]), for 10 IoU levels
torch.Tensor: Correct prediction matrix of shape [N, 10] for 10 IoU levels.
"""
if pred_kpts is not None and gt_kpts is not None:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384