ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
Glenn Jocher
2023-05-17 19:10:20 +02:00
committed by GitHub
parent b1119d512e
commit 23fc50641c
92 changed files with 378 additions and 206 deletions

View File

@ -45,17 +45,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
images = batch['img']
masks = batch['masks']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True) # save results.png
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses

View File

@ -20,7 +20,7 @@ class SegmentationValidator(DetectionValidator):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir)
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
@ -174,7 +174,8 @@ class SegmentationValidator(DetectionValidator):
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
@ -183,7 +184,8 @@ class SegmentationValidator(DetectionValidator):
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):