Update val max_dets=args.max_det=300 (#3051)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-06-07 08:17:39 +02:00
committed by GitHub
parent 8ac8ff72ae
commit 3ae81ee9d1
3 changed files with 10 additions and 9 deletions

View File

@ -179,13 +179,14 @@ class SegmentationValidator(DetectionValidator):
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
plot_images(batch['img'],
*output_to_target(preds[0], max_det=15),
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
plot_images(
batch['img'],
*output_to_target(preds[0], max_det=15), # not set to self.args.max_det due to slow plotting speed
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):