ultralytics 8.0.105 classification hyp fix and new onplot callbacks (#2684)

Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <shcheklein@gmail.com>
This commit is contained in:
Glenn Jocher
2023-05-17 19:10:20 +02:00
committed by GitHub
parent b1119d512e
commit 23fc50641c
92 changed files with 378 additions and 206 deletions

View File

@ -71,7 +71,7 @@ class ClassificationTrainer(BaseTrainer):
return # dont return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=mode == 'train')
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
@ -126,7 +126,7 @@ class ClassificationTrainer(BaseTrainer):
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True) # save results.png
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
@ -147,7 +147,8 @@ class ClassificationTrainer(BaseTrainer):
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):

View File

@ -47,7 +47,10 @@ class ClassificationValidator(BaseValidator):
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
@ -57,7 +60,7 @@ class ClassificationValidator(BaseValidator):
return self.metrics.results_dict
def build_dataset(self, img_path):
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
dataset = ClassificationDataset(root=img_path, args=self.args, augment=False)
return dataset
def get_dataloader(self, dataset_path, batch_size):
@ -76,7 +79,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
@ -84,7 +88,8 @@ class ClassificationValidator(BaseValidator):
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):

View File

@ -121,17 +121,18 @@ class DetectionTrainer(BaseTrainer):
cls=batch['cls'].squeeze(-1),
bboxes=batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv) # save results.png
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
# Criterion class for computing training losses

View File

@ -24,7 +24,7 @@ class DetectionValidator(BaseValidator):
self.args.task = 'detect'
self.is_coco = False
self.class_map = None
self.metrics = DetMetrics(save_dir=self.save_dir)
self.metrics = DetMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
self.iouv = torch.linspace(0.5, 0.95, 10) # iou vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
@ -145,7 +145,10 @@ class DetectionValidator(BaseValidator):
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, names=self.names.values(), normalize=normalize)
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
def _process_batch(self, detections, labels):
"""
@ -215,7 +218,8 @@ class DetectionValidator(BaseValidator):
batch['bboxes'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
@ -223,7 +227,8 @@ class DetectionValidator(BaseValidator):
*output_to_target(preds, max_det=15),
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""

View File

@ -65,11 +65,12 @@ class PoseTrainer(v8.detect.DetectionTrainer):
bboxes,
kpts=kpts,
paths=paths,
fname=self.save_dir / f'train_batch{ni}.jpg')
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, pose=True) # save results.png
plot_results(file=self.csv, pose=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses

View File

@ -18,7 +18,7 @@ class PoseValidator(DetectionValidator):
"""Initialize a 'PoseValidator' object with custom parameters and assigned attributes."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'pose'
self.metrics = PoseMetrics(save_dir=self.save_dir)
self.metrics = PoseMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses the batch by converting the 'keypoints' data into a float and moving it to the device."""
@ -150,7 +150,8 @@ class PoseValidator(DetectionValidator):
kpts=batch['keypoints'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predictions for YOLO model."""
@ -160,7 +161,8 @@ class PoseValidator(DetectionValidator):
kpts=pred_kpts,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
def pred_to_json(self, predn, filename):
"""Converts YOLO predictions to COCO JSON format."""

View File

@ -45,17 +45,18 @@ class SegmentationTrainer(v8.detect.DetectionTrainer):
def plot_training_samples(self, batch, ni):
"""Creates a plot of training sample images with labels and box coordinates."""
images = batch['img']
masks = batch['masks']
cls = batch['cls'].squeeze(-1)
bboxes = batch['bboxes']
paths = batch['im_file']
batch_idx = batch['batch_idx']
plot_images(images, batch_idx, cls, bboxes, masks, paths=paths, fname=self.save_dir / f'train_batch{ni}.jpg')
plot_images(batch['img'],
batch['batch_idx'],
batch['cls'].squeeze(-1),
batch['bboxes'],
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def plot_metrics(self):
"""Plots training/val metrics."""
plot_results(file=self.csv, segment=True) # save results.png
plot_results(file=self.csv, segment=True, on_plot=self.on_plot) # save results.png
# Criterion class for computing training losses

View File

@ -20,7 +20,7 @@ class SegmentationValidator(DetectionValidator):
"""Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'segment'
self.metrics = SegmentMetrics(save_dir=self.save_dir)
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
def preprocess(self, batch):
"""Preprocesses batch by converting masks to float and sending to device."""
@ -174,7 +174,8 @@ class SegmentationValidator(DetectionValidator):
batch['masks'],
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots batch predictions with masks and bounding boxes."""
@ -183,7 +184,8 @@ class SegmentationValidator(DetectionValidator):
torch.cat(self.plot_masks, dim=0) if len(self.plot_masks) else self.plot_masks,
paths=batch['im_file'],
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
names=self.names,
on_plot=self.on_plot) # pred
self.plot_masks.clear()
def pred_to_json(self, predn, filename, pred_masks):