ultralytics 8.0.81
single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -10,10 +10,12 @@ from ultralytics.yolo.utils import DEFAULT_CFG, ROOT
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Converts input image to model-compatible data type."""
|
||||
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
||||
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Postprocesses predictions to return Results objects."""
|
||||
results = []
|
||||
for i, pred in enumerate(preds):
|
||||
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
|
||||
@ -25,6 +27,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
|
||||
|
||||
def predict(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Run YOLO model predictions on input images/videos."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
|
||||
else 'https://ultralytics.com/images/bus.jpg'
|
||||
|
@ -14,15 +14,18 @@ from ultralytics.yolo.utils.torch_utils import is_parallel, strip_optimizer
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides['task'] = 'classify'
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set the YOLO model's class names from the loaded dataset."""
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Returns a modified PyTorch model configured for training YOLO."""
|
||||
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
@ -69,6 +72,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
|
||||
loader = build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size if mode == 'train' else (batch_size * 2),
|
||||
@ -84,19 +88,23 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Preprocesses a batch of images and classes."""
|
||||
batch['img'] = batch['img'].to(self.device)
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a formatted string showing training progress."""
|
||||
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
|
||||
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ['loss']
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch['cls'], reduction='sum') / self.args.nbs
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
@ -113,9 +121,11 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resumes training from a given checkpoint."""
|
||||
pass
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
@ -130,6 +140,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
|
||||
def train(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Train the YOLO classification model."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
|
||||
device = cfg.device if cfg.device is not None else ''
|
||||
|
@ -9,14 +9,17 @@ from ultralytics.yolo.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
class ClassificationValidator(BaseValidator):
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
|
||||
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
|
||||
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
|
||||
self.args.task = 'classify'
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self):
|
||||
"""Returns a formatted string summarizing classification metrics."""
|
||||
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
|
||||
@ -24,17 +27,20 @@ class ClassificationValidator(BaseValidator):
|
||||
self.targets = []
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocesses input batch and returns it."""
|
||||
batch['img'] = batch['img'].to(self.device, non_blocking=True)
|
||||
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
|
||||
batch['cls'] = batch['cls'].to(self.device)
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Updates running metrics with model predictions and batch targets."""
|
||||
n5 = min(len(self.model.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
|
||||
self.targets.append(batch['cls'])
|
||||
|
||||
def finalize_metrics(self, *args, **kwargs):
|
||||
"""Finalizes metrics of the model such as confusion_matrix and speed."""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, names=list(self.names.values()))
|
||||
@ -42,10 +48,12 @@ class ClassificationValidator(BaseValidator):
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def get_stats(self):
|
||||
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
return self.metrics.results_dict
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Builds and returns a data loader for classification tasks with given parameters."""
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
@ -54,11 +62,13 @@ class ClassificationValidator(BaseValidator):
|
||||
workers=self.args.workers)
|
||||
|
||||
def print_results(self):
|
||||
"""Prints evaluation metrics for YOLO object detection model."""
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
|
||||
|
||||
|
||||
def val(cfg=DEFAULT_CFG, use_python=False):
|
||||
"""Validate YOLO model using custom data."""
|
||||
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
|
||||
data = cfg.data or 'mnist160'
|
||||
|
||||
|
Reference in New Issue
Block a user