ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
@ -47,7 +47,10 @@ class ClassificationValidator(BaseValidator):
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir,
|
||||
names=self.names.values(),
|
||||
normalize=normalize,
|
||||
on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
@ -57,7 +60,7 @@ class ClassificationValidator(BaseValidator):
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
|
||||
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
@ -76,7 +79,8 @@ class ClassificationValidator(BaseValidator):
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names)
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
@ -84,7 +88,8 @@ class ClassificationValidator(BaseValidator):
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names) # pred
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
|
||||
Reference in New Issue
Block a user