ultralytics 8.0.136 refactor and simplify package (#3748)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing
2023-07-16 23:47:45 +08:00
committed by GitHub
parent 8ebe94d1e9
commit 620f3eb218
383 changed files with 4213 additions and 4646 deletions

View File

@ -0,0 +1,7 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor, predict
from ultralytics.models.yolo.classify.train import ClassificationTrainer, train
from ultralytics.models.yolo.classify.val import ClassificationValidator, val
__all__ = 'ClassificationPredictor', 'predict', 'ClassificationTrainer', 'train', 'ClassificationValidator', 'val'

View File

@ -0,0 +1,51 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ROOT
class ClassificationPredictor(BasePredictor):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
img = torch.stack([self.transforms(im) for im in img], dim=0)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions to return Results objects."""
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs
path = self.batch[0]
img_path = path[i] if isinstance(path, list) else path
results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred))
return results
def predict(cfg=DEFAULT_CFG, use_python=False):
"""Run YOLO model predictions on input images/videos."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
source = cfg.source if cfg.source is not None else ROOT / 'assets' if (ROOT / 'assets').exists() \
else 'https://ultralytics.com/images/bus.jpg'
args = dict(model=model, source=source)
if use_python:
from ultralytics import YOLO
YOLO(model)(**args)
else:
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()
if __name__ == '__main__':
predict()

View File

@ -0,0 +1,161 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
import torchvision
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, colorstr
from ultralytics.utils.plotting import plot_images, plot_results
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""Initialize a ClassificationTrainer object with optional configuration overrides and callbacks."""
if overrides is None:
overrides = {}
overrides['task'] = 'classify'
if overrides.get('imgsz') is None:
overrides['imgsz'] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data['names']
def get_model(self, cfg=None, weights=None, verbose=True):
"""Returns a modified PyTorch model configured for training YOLO."""
model = ClassificationModel(cfg, nc=self.data['nc'], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, 'reset_parameters'):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""
load/create/download model for any task
"""
# Classification models require special handling
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith('.pt'):
self.model, _ = attempt_load_one_weight(model, device='cpu')
for p in self.model.parameters():
p.requires_grad = True # for training
elif model.endswith('.yaml'):
self.model = self.get_model(cfg=model)
elif model in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if self.args.pretrained else None)
else:
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
"""Returns PyTorch DataLoader with transforms to preprocess images for inference."""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank)
# Attach inference transforms
if mode != 'train':
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch):
"""Preprocesses a batch of images and classes."""
batch['img'] = batch['img'].to(self.device)
batch['cls'] = batch['cls'].to(self.device)
return batch
def progress_string(self):
"""Returns a formatted string showing training progress."""
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
"""Returns an instance of ClassificationValidator for validation."""
self.loss_names = ['loss']
return yolo.classify.ClassificationValidator(self.test_loader, self.save_dir)
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def resume_training(self, ckpt):
"""Resumes training from a given checkpoint."""
pass
def plot_metrics(self):
"""Plots metrics from a CSV file."""
plot_results(file=self.csv, classify=True, on_plot=self.on_plot) # save results.png
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
# TODO: validate best.pt after training completes
# if f is self.best:
# LOGGER.info(f'\nValidating {f}...')
# self.validator.args.save_json = True
# self.metrics = self.validator(model=f)
# self.metrics.pop('fitness', None)
# self.run_callbacks('on_fit_epoch_end')
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
def plot_training_samples(self, batch, ni):
"""Plots training samples with their annotations."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'train_batch{ni}.jpg',
on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = ClassificationTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()

View File

@ -0,0 +1,109 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import DEFAULT_CFG, LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = 'classify'
self.metrics = ClassifyMetrics()
def get_desc(self):
"""Returns a formatted string summarizing classification metrics."""
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
def init_metrics(self, model):
"""Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
self.names = model.names
self.nc = len(model.names)
self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
self.pred = []
self.targets = []
def preprocess(self, batch):
"""Preprocesses input batch and returns it."""
batch['img'] = batch['img'].to(self.device, non_blocking=True)
batch['img'] = batch['img'].half() if self.args.half else batch['img'].float()
batch['cls'] = batch['cls'].to(self.device)
return batch
def update_metrics(self, preds, batch):
"""Updates running metrics with model predictions and batch targets."""
n5 = min(len(self.model.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):
"""Finalizes metrics of the model such as confusion_matrix and speed."""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir,
names=self.names.values(),
normalize=normalize,
on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
def get_stats(self):
"""Returns a dictionary of metrics obtained by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path):
return ClassificationDataset(root=img_path, args=self.args, augment=False)
def get_dataloader(self, dataset_path, batch_size):
"""Builds and returns a data loader for classification tasks with given parameters."""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self):
"""Prints evaluation metrics for YOLO object detection model."""
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
LOGGER.info(pf % ('all', self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch, ni):
"""Plot validation image samples."""
plot_images(images=batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=batch['cls'].squeeze(-1),
fname=self.save_dir / f'val_batch{ni}_labels.jpg',
names=self.names,
on_plot=self.on_plot)
def plot_predictions(self, batch, preds, ni):
"""Plots predicted bounding boxes on input images and saves the result."""
plot_images(batch['img'],
batch_idx=torch.arange(len(batch['img'])),
cls=torch.argmax(preds, dim=1),
fname=self.save_dir / f'val_batch{ni}_pred.jpg',
names=self.names,
on_plot=self.on_plot) # pred
def val(cfg=DEFAULT_CFG, use_python=False):
"""Validate YOLO model using custom data."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160'
args = dict(model=model, data=data)
if use_python:
from ultralytics import YOLO
YOLO(model).val(**args)
else:
validator = ClassificationValidator(args=args)
validator(model=args['model'])
if __name__ == '__main__':
val()