ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
Glenn Jocher
2023-05-17 19:10:20 +02:00
committed by GitHub
parent b1119d512e
commit 23fc50641c
92 changed files with 378 additions and 206 deletions

View File

@ -71,7 +71,7 @@ class ClassificationTrainer(BaseTrainer):
return # dont return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
@ -126,7 +126,7 @@ class ClassificationTrainer(BaseTrainer):
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True) # save results.png
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
@ -147,7 +147,8 @@ class ClassificationTrainer(BaseTrainer):
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):

View File

@ -47,7 +47,10 @@ class ClassificationValidator(BaseValidator):
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
@ -57,7 +60,7 @@ class ClassificationValidator(BaseValidator):
return self.metrics.results_dict
def build_dataset(self, img_path):
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
return dataset
def get_dataloader(self, dataset_path, batch_size):
@ -76,7 +79,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
@ -84,7 +88,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):