Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-08 00:34:34 +05:30
committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions

View File

@ -1,5 +1,3 @@
from pathlib import Path
import hydra
import torch
import torchvision
@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
super().__init__(config, overrides)
@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
# Update defaults
if self.args.imgsz == 640:
self.args.imgsz = 224
return model
def setup_model(self):
@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = self.model
pretrained = False
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
pretrained = True
else:
self.model = attempt_load_weights(model, device='cpu')
elif model.endswith(".yaml"):
self.model = self.get_model(cfg=model)
# order: check local file -> torchvision assets -> ultralytics asset
if Path(f"{model}.pt").is_file(): # local file
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
return # dont return ckpt. Classification doesn't support resume
@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
return loss, loss
def check_resume(self):
pass
def resume_training(self, ckpt):
pass
@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.model = cfg.model or "resnet18"
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.train()
# trainer = ClassificationTrainer(cfg)
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
if __name__ == "__main__":