Update val `max_dets=args.max_det=300` (#3051)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 8ac8ff72ae
commit 3ae81ee9d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -223,7 +223,7 @@ class DetectionValidator(BaseValidator):
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result.""" """Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'], plot_images(batch['img'],
*output_to_target(preds, max_det=15), *output_to_target(preds, max_det=self.args.max_det),
paths=batch['im_file'], paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg', fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names, names=self.names,

@ -156,7 +156,7 @@ class PoseValidator(DetectionValidator):
"""Plots predictions for YOLO model.""" """Plots predictions for YOLO model."""
pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0) pred_kpts = torch.cat([p[:, 6:].view(-1, *self.kpt_shape)[:15] for p in preds], 0)
plot_images(batch['img'], plot_images(batch['img'],
*output_to_target(preds, max_det=15), *output_to_target(preds, max_det=self.args.max_det),
kpts=pred_kpts, kpts=pred_kpts,
paths=batch['im_file'], paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg', fname=self.save_dir / f'val_batch{ni}_pred.jpg',

@ -179,8 +179,9 @@ class SegmentationValidator(DetectionValidator):
def plot_predictions(self, batch, preds, ni): def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes.""" """Plots batch predictions with masks and bounding boxes."""
plot_images(batch['img'], plot_images(
*output_to_target(preds[0], max_det=15), batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks, torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'], paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg', fname=self.save_dir / f'val_batch{ni}_pred.jpg',

Loading…
Cancel
Save