ultralytics 8.0.81
single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -8,6 +8,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor
|
||||
class PosePredictor(DetectionPredictor):
|
||||
|
||||
def postprocess(self, preds, img, orig_img):
|
||||
"""Return detection results for a given input image or list of images."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -35,6 +36,7 @@ class PosePredictor(DetectionPredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Runs YOLO to predict objects in an image or video."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -21,12 +21,14 @@ from ultralytics.yolo.v8.detect.train import Loss
|
||||
class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a PoseTrainer object with specified configurations and overrides."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'pose'
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get pose estimation model with specified configuration and weights."""
|
||||
model = PoseModel(cfg, ch=3, nc=self.data['nc'], data_kpt_shape=self.data['kpt_shape'], verbose=verbose)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -34,19 +36,23 @@ class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Sets keypoints shape attribute of PoseModel."""
|
||||
super().set_model_attributes()
|
||||
self.model.kpt_shape = self.data['kpt_shape']
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = 'box_loss', 'pose_loss', 'kobj_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.pose.PoseValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""Computes pose loss for the YOLO model."""
|
||||
if not hasattr(self, 'compute_loss'):
|
||||
self.compute_loss = PoseLoss(de_parallel(self.model))
|
||||
return self.compute_loss(preds, batch)
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plot a batch of training samples with annotated class labels, bounding boxes, and keypoints."""
|
||||
images = batch['img']
|
||||
kpts = batch['keypoints']
|
||||
cls = batch['cls'].squeeze(-1)
|
||||
@ -62,6 +68,7 @@ class PoseTrainer(v8.detect.DetectionTrainer):
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots training/val metrics."""
|
||||
plot_results(file=self.csv, pose=True) # save results.png
|
||||
|
||||
|
||||
@ -78,6 +85,7 @@ class PoseLoss(Loss):
|
||||
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
|
||||
|
||||
def __call__(self, preds, batch):
|
||||
"""Calculate the total loss and detach it."""
|
||||
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
|
||||
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
|
||||
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
|
||||
@ -145,6 +153,7 @@ class PoseLoss(Loss):
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
def kpts_decode(self, anchor_points, pred_kpts):
|
||||
"""Decodes predicted keypoints to image coordinates."""
|
||||
y = pred_kpts.clone()
|
||||
y[..., :2] *= 2.0
|
||||
y[..., 0] += anchor_points[:, [0]] - 0.5
|
||||
@ -153,6 +162,7 @@ class PoseLoss(Loss):
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO model on the given data and device."""
|
||||
model = cfg.model or 'yolov8n-pose.yaml'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -15,20 +15,24 @@ from ultralytics.yolo.v8.detect import DetectionValidator
|
||||
class PoseValidator(DetectionValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'pose'
|
||||
self.metrics = PoseMetrics(save_dir=self.save_dir)
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch['keypoints'] = batch['keypoints'].to(self.device).float()
|
||||
return batch
|
||||
|
||||
def get_desc(self):
|
||||
"""Returns description of evaluation metrics in string format."""
|
||||
return ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', 'R', 'mAP50', 'mAP50-95)', 'Pose(P',
|
||||
'R', 'mAP50', 'mAP50-95)')
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Apply non-maximum suppression and return detections with high confidence scores."""
|
||||
preds = ops.non_max_suppression(preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
@ -40,6 +44,7 @@ class PoseValidator(DetectionValidator):
|
||||
return preds
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initiate pose estimation metrics for YOLO model."""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data['kpt_shape']
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
@ -137,6 +142,7 @@ class PoseValidator(DetectionValidator):
|
||||
return torch.tensor(correct, dtype=torch.bool, device=detections.device)
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plots and saves validation set samples with predicted bounding boxes and keypoints."""
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
@ -147,6 +153,7 @@ class PoseValidator(DetectionValidator):
|
||||
names=self.names)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predictions for YOLO model."""
|
||||
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0)
|
||||
plot_images(batch['img'],
|
||||
*output_to_target(preds, max_det=15),
|
||||
@ -156,6 +163,7 @@ class PoseValidator(DetectionValidator):
|
||||
names=self.names) # pred
|
||||
|
||||
def pred_to_json(self, predn, filename):
|
||||
"""Converts YOLO predictions to COCO JSON format."""
|
||||
stem = Path(filename).stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn[:, :4]) # xywh
|
||||
@ -169,6 +177,7 @@ class PoseValidator(DetectionValidator):
|
||||
'score': round(p[4], 5)})
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluates object detection model using COCO JSON format."""
|
||||
if self.args.save_json and self.is_coco and len(self.jdict):
|
||||
anno_json = self.data['path'] / 'annotations/person_keypoints_val2017.json' # annotations
|
||||
pred_json = self.save_dir / 'predictions.json' # predictions
|
||||
@ -197,6 +206,7 @@ class PoseValidator(DetectionValidator):
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Performs validation on YOLO model using given data."""
|
||||
model = cfg.model or 'yolov8n-pose.pt'
|
||||
data = cfg.data or 'coco8-pose.yaml'
|
||||
|
||||
|
Reference in New Issue
Block a user