ultralytics 8.0.89 SAM predict and auto-annotate (#2298)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: Paula Derrenger <107626595+pderrenger@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
Co-authored-by: Snyk bot <snyk-bot@snyk.io>
Co-authored-by: Laughing-q <1185102784@qq.com>
This commit is contained in:
Glenn Jocher
2023-04-28 00:36:50 +02:00
committed by GitHub
parent 3e118f6170
commit 243fc4b1fe
44 changed files with 2915 additions and 440 deletions

View File

@ -1,9 +1,12 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.yolo.data import build_classification_dataloader
import torch
from ultralytics.yolo.data import ClassificationDataset, build_dataloader
from ultralytics.yolo.engine.validator import BaseValidator
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER
from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.yolo.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
@ -52,20 +55,36 @@ class ClassificationValidator(BaseValidator):
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path):
dataset = ClassificationDataset(root=img_path, imgsz=self.args.imgsz, augment=False)
return dataset
def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
batch_size=batch_size,
augment=False,
shuffle=False,
workers=self.args.workers)
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names) # pred
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate YOLO model using custom data."""