Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-08 00:34:34 +05:30
committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions

View File

@ -1,5 +1,3 @@
from pathlib import Path
import hydra
import torch
import torchvision
@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
super().__init__(config, overrides)
@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
# Update defaults
if self.args.imgsz == 640:
self.args.imgsz = 224
return model
def setup_model(self):
@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = self.model
pretrained = False
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
pretrained = True
else:
self.model = attempt_load_weights(model, device='cpu')
elif model.endswith(".yaml"):
self.model = self.get_model(cfg=model)
# order: check local file -> torchvision assets -> ultralytics asset
if Path(f"{model}.pt").is_file(): # local file
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
return # dont return ckpt. Classification doesn't support resume
@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
return loss, loss
def check_resume(self):
pass
def resume_training(self, ckpt):
pass
@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.model = cfg.model or "resnet18"
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.train()
# trainer = ClassificationTrainer(cfg)
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
if __name__ == "__main__":

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.00 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.67 # scales module repeats
width_multiple: 0.75 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.50 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]