Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-08 00:34:34 +05:30
committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions

View File

@ -255,12 +255,28 @@ def check_dataset_yaml(data, autodownload=True):
def check_dataset(dataset: str):
data = Path.cwd() / "datasets" / dataset
data_dir = data if data.is_dir() else (Path.cwd() / data)
"""
Check a classification dataset such as Imagenet.
Copy code
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
If the dataset is not found, it attempts to download the dataset from the internet and save it to the local file system.
Args:
dataset (str): Name of the dataset.
Returns:
data (dict): A dictionary containing the following keys and values:
'train': Path object for the directory containing the training set of the dataset
'val': Path object for the directory containing the validation set of the dataset
'nc': Number of classes in the dataset
'names': List of class names in the dataset
"""
data_dir = (Path.cwd() / "datasets" / dataset).resolve()
if not data_dir.is_dir():
LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
t = time.time()
if str(data) == 'imagenet':
if dataset == 'imagenet':
subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
@ -271,5 +287,4 @@ def check_dataset(dataset: str):
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
names = [name for name in os.listdir(data_dir / 'train') if os.path.isdir(data_dir / 'train' / name)]
data = {"train": train_set, "val": test_set, "nc": nc, "names": names}
return data
return {"train": train_set, "val": test_set, "nc": nc, "names": names}

View File

@ -103,13 +103,9 @@ class YOLO:
Args:
verbose (bool): Controls verbosity.
"""
if not self.model:
LOGGER.info("model not initialized!")
self.model.info(verbose=verbose)
def fuse(self):
if not self.model:
LOGGER.info("model not initialized!")
self.model.fuse()
@smart_inference_mode()
@ -139,9 +135,6 @@ class YOLO:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
if not self.model:
raise ModuleNotFoundError("model not initialized!")
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
@ -177,8 +170,6 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
"""
if not self.model:
raise AttributeError("model not initialized. Use .new() or .load()")
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
@ -193,10 +184,8 @@ class YOLO:
self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
cfg=self.model.yaml if self.task != "classify" else None)
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
def to(self, device):

View File

@ -1,5 +1,3 @@
from pathlib import Path
import hydra
import torch
import torchvision
@ -13,7 +11,9 @@ from ultralytics.yolo.utils import DEFAULT_CONFIG
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CONFIG, overrides={}):
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
super().__init__(config, overrides)
@ -25,6 +25,10 @@ class ClassificationTrainer(BaseTrainer):
if weights:
model.load(weights)
# Update defaults
if self.args.imgsz == 640:
self.args.imgsz = 224
return model
def setup_model(self):
@ -36,22 +40,17 @@ class ClassificationTrainer(BaseTrainer):
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
model = self.model
pretrained = False
model = str(self.model)
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"):
model = model.split(".")[0]
pretrained = True
else:
self.model = attempt_load_weights(model, device='cpu')
elif model.endswith(".yaml"):
self.model = self.get_model(cfg=model)
# order: check local file -> torchvision assets -> ultralytics asset
if Path(f"{model}.pt").is_file(): # local file
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
elif model in torchvision.models.__dict__:
pretrained = True
self.model = torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else:
self.model = attempt_load_weights(f"{model}.pt", device='cpu')
FileNotFoundError(f'ERROR: model={model} not found locally or online. Please check model name.')
return # dont return ckpt. Classification doesn't support resume
@ -66,6 +65,10 @@ class ClassificationTrainer(BaseTrainer):
batch["cls"] = batch["cls"].to(self.device)
return batch
def progress_string(self):
return ('\n' + '%11s' *
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
def get_validator(self):
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
@ -73,9 +76,6 @@ class ClassificationTrainer(BaseTrainer):
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
return loss, loss
def check_resume(self):
pass
def resume_training(self, ckpt):
pass
@ -85,10 +85,13 @@ class ClassificationTrainer(BaseTrainer):
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
def train(cfg):
cfg.model = cfg.model or "resnet18"
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
trainer = ClassificationTrainer(cfg)
trainer.train()
# trainer = ClassificationTrainer(cfg)
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
if __name__ == "__main__":

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.00 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.67 # scales module repeats
width_multiple: 0.75 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 0.33 # scales module repeats
width_multiple: 0.50 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]

View File

@ -0,0 +1,23 @@
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
# Parameters
nc: 1000 # number of classes
depth_multiple: 1.00 # scales module repeats
width_multiple: 1.25 # scales convolution channels
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
# YOLOv8.0n head
head:
- [-1, 1, Classify, [nc]]