ultralytics 8.0.105
classification hyp fix and new onplot
callbacks (#2684)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
@ -45,17 +45,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Creates a plot of training sample images with labels and box coordinates."""
|
||||
images = batch['img']
|
||||
masks = batch['masks']
|
||||
cls = batch['cls'].squeeze(-1)
|
||||
bboxes = batch['bboxes']
|
||||
paths = batch['im_file']
|
||||
batch_idx = batch['batch_idx']
|
||||
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
plot_images(batch['img'],
|
||||
batch['batch_idx'],
|
||||
batch['cls'].squeeze(-1),
|
||||
batch['bboxes'],
|
||||
batch['masks'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots training/val metrics."""
|
||||
plot_results(file=self.csv, segment=True) # save results.png
|
||||
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
|
||||
# Criterion class for computing training losses
|
||||
|
@ -20,7 +20,7 @@ class SegmentationValidator(DetectionValidator):
|
||||
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'segment'
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir)
|
||||
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses batch by converting masks to float and sending to device."""
|
||||
@ -174,7 +174,8 @@ class SegmentationValidator(DetectionValidator):
|
||||
batch['masks'],
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names)
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots batch predictions with masks and bounding boxes."""
|
||||
@ -183,7 +184,8 @@ class SegmentationValidator(DetectionValidator):
|
||||
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
|
||||
paths=batch['im_file'],
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names) # pred
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
self.plot_masks.clear()
|
||||
|
||||
def pred_to_json(self, predn, filename, pred_masks):
|
||||
|
Reference in New Issue
Block a user