Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
@ -1,5 +1,3 @@
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torchvision
|
||||
@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "classify"
|
||||
super().__init__(config, overrides)
|
||||
|
||||
@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
# Update defaults
|
||||
if self.args.imgsz == 640:
|
||||
self.args.imgsz = 224
|
||||
|
||||
return model
|
||||
|
||||
def setup_model(self):
|
||||
@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
model = self.model
|
||||
pretrained = False
|
||||
model = str(self.model)
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
model = model.split(".")[0]
|
||||
pretrained = True
|
||||
else:
|
||||
self.model = attempt_load_weights(model, device='cpu')
|
||||
elif model.endswith(".yaml"):
|
||||
self.model = self.get_model(cfg=model)
|
||||
|
||||
# order: check local file -> torchvision assets -> ultralytics asset
|
||||
if Path(f"{model}.pt").is_file(): # local file
|
||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
elif model in torchvision.models.__dict__:
|
||||
pretrained = True
|
||||
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
|
||||
else:
|
||||
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
|
||||
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
|
||||
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
|
||||
batch["cls"] = batch["cls"].to(self.device)
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
return ('\n' + '%11s' *
|
||||
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def get_validator(self):
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
||||
|
||||
@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
|
||||
return loss, loss
|
||||
|
||||
def check_resume(self):
|
||||
pass
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
pass
|
||||
|
||||
@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "resnet18"
|
||||
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = ClassificationTrainer(cfg)
|
||||
trainer.train()
|
||||
# trainer = ClassificationTrainer(cfg)
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Reference in New Issue
Block a user