ultralytics 8.0.153 YOLO Tasks Cleanup (#4314)

This commit is contained in:
Glenn Jocher
2023-08-12 02:30:57 +02:00
committed by GitHub
parent 39395aedc8
commit 822608986c
22 changed files with 87 additions and 55 deletions

View File

@ -21,7 +21,7 @@ class ClassificationPredictor(BasePredictor):
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""Postprocesses predictions to return Results objects."""
"""Post-processes predictions to return Results objects."""
results = []
for i, pred in enumerate(preds):
orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs

View File

@ -43,11 +43,7 @@ class ClassificationTrainer(BaseTrainer):
return model
def setup_model(self):
"""
load/create/download model for any task
"""
# Classification models require special handling
"""load/create/download model for any task"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
@ -65,7 +61,7 @@ class ClassificationTrainer(BaseTrainer):
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
ClassificationModel.reshape_outputs(self.model, self.data['nc'])
return # dont return ckpt. Classification doesn't support resume
return # do not return ckpt. Classification doesn't support resume
def build_dataset(self, img_path, mode='train', batch=None):
return ClassificationDataset(root=img_path, args=self.args, augment=mode == 'train')
@ -102,9 +98,9 @@ class ClassificationTrainer(BaseTrainer):
def label_loss_items(self, loss_items=None, prefix='train'):
"""
Returns a loss dict with labelled training loss items tensor
Returns a loss dict with labelled training loss items tensor. Not needed for classification but necessary for
segmentation & detection
"""
# Not needed for classification but necessary for segmentation & detection
keys = [f'{prefix}/{x}' for x in self.loss_names]
if loss_items is None:
return keys
@ -144,7 +140,7 @@ class ClassificationTrainer(BaseTrainer):
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train the YOLO classification model."""
"""Train a YOLO classification model."""
model = cfg.model or 'yolov8n-cls.pt' # or "resnet18"
data = cfg.data or 'mnist160' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''

View File

@ -14,6 +14,8 @@ class ClassificationValidator(BaseValidator):
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""Initializes ClassificationValidator instance with args, dataloader, save_dir, and progress bar."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = 'classify'
self.metrics = ClassifyMetrics()