ultralytics 8.0.105
classification hyp fix and new onplot
callbacks (#2684)
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
@ -71,7 +71,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
@ -126,7 +126,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plots metrics from a CSV file."""
|
||||
plot_results(file=self.csv, classify=True) # save results.png
|
||||
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
@ -147,7 +147,8 @@ class ClassificationTrainer(BaseTrainer):
|
||||
plot_images(images=batch['img'],
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg')
|
||||
fname=self.save_dir / f'train_batch{ni}.jpg',
|
||||
on_plot=self.on_plot)
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
|
@ -47,7 +47,10 @@ class ClassificationValidator(BaseValidator):
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir,
|
||||
names=self.names.values(),
|
||||
normalize=normalize,
|
||||
on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
@ -57,7 +60,7 @@ class ClassificationValidator(BaseValidator):
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
|
||||
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
|
||||
return dataset
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
@ -76,7 +79,8 @@ class ClassificationValidator(BaseValidator):
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=batch['cls'].squeeze(-1),
|
||||
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
|
||||
names=self.names)
|
||||
names=self.names,
|
||||
on_plot=self.on_plot)
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plots predicted bounding boxes on input images and saves the result."""
|
||||
@ -84,7 +88,8 @@ class ClassificationValidator(BaseValidator):
|
||||
batch_idx=torch.arange(len(batch['img'])),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
|
||||
names=self.names) # pred
|
||||
names=self.names,
|
||||
on_plot=self.on_plot) # pred
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
|
Reference in New Issue
Block a user