From 3ae81ee9d11c432189e109d7a1724635a2e451ca Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 7 Jun 2023 08:17:39 +0200 Subject: [PATCH] Update val `max_dets=args.max_det=300` (#3051) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/v8/detect/val.py | 2 +- ultralytics/yolo/v8/pose/val.py | 2 +- ultralytics/yolo/v8/segment/val.py | 15 ++++++++------- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/ultralytics/yolo/v8/detect/val.py b/ultralytics/yolo/v8/detect/val.py index 7759a43..77d346c 100644 --- a/ultralytics/yolo/v8/detect/val.py +++ b/ultralytics/yolo/v8/detect/val.py @@ -223,7 +223,7 @@ class DetectionValidator(BaseValidator): def plot_predictions(self, batch, preds, ni): """Plots predicted bounding boxes on input images and saves the result.""" plot_images(batch['img'], - *output_to_target(preds, max_det=15), + *output_to_target(preds, max_det=self.args.max_det), paths=batch['im_file'], fname=self.save_dir / f'val_batch{ni}_pred.jpg', names=self.names, diff --git a/ultralytics/yolo/v8/pose/val.py b/ultralytics/yolo/v8/pose/val.py index a062727..16afba4 100644 --- a/ultralytics/yolo/v8/pose/val.py +++ b/ultralytics/yolo/v8/pose/val.py @@ -156,7 +156,7 @@ class PoseValidator(DetectionValidator): """Plots predictions for YOLO model.""" pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0) plot_images(batch['img'], - *output_to_target(preds, max_det=15), + *output_to_target(preds, max_det=self.args.max_det), kpts=pred_kpts, paths=batch['im_file'], fname=self.save_dir / f'val_batch{ni}_pred.jpg', diff --git a/ultralytics/yolo/v8/segment/val.py b/ultralytics/yolo/v8/segment/val.py index 5c13885..73c2fe8 100644 --- a/ultralytics/yolo/v8/segment/val.py +++ b/ultralytics/yolo/v8/segment/val.py @@ -179,13 +179,14 @@ class SegmentationValidator(DetectionValidator): def plot_predictions(self, batch, preds, ni): """Plots batch predictions with masks and bounding boxes.""" - plot_images(batch['img'], - *output_to_target(preds[0], max_det=15), - torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, - paths=batch['im_file'], - fname=self.save_dir / f'val_batch{ni}_pred.jpg', - names=self.names, - on_plot=self.on_plot) # pred + plot_images( + batch['img'], + *output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed + torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, + paths=batch['im_file'], + fname=self.save_dir / f'val_batch{ni}_pred.jpg', + names=self.names, + on_plot=self.on_plot) # pred self.plot_masks.clear() def pred_to_json(self, predn, filename, pred_masks):