Fix Classification train logging (#157)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
@ -2,7 +2,7 @@ import hydra
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_weights
|
||||
from ultralytics.nn.tasks import ClassificationModel, attempt_load_one_weight
|
||||
from ultralytics.yolo import v8
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.trainer import BaseTrainer
|
||||
@ -20,8 +20,18 @@ class ClassificationTrainer(BaseTrainer):
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def get_model(self, cfg=None, weights=None):
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
model = ClassificationModel(cfg, nc=self.data["nc"])
|
||||
|
||||
pretrained = False
|
||||
for m in model.modules():
|
||||
if not pretrained and hasattr(m, 'reset_parameters'):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
@ -43,7 +53,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||
model = str(self.model)
|
||||
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
|
||||
if model.endswith(".pt"):
|
||||
self.model = attempt_load_weights(model, device='cpu')
|
||||
self.model, _ = attempt_load_one_weight(model, device='cpu')
|
||||
elif model.endswith(".yaml"):
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
@ -54,10 +64,11 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0, mode="train"):
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size,
|
||||
batch_size=batch_size if mode == "train" else (batch_size * 2),
|
||||
augment=mode == "train",
|
||||
rank=rank)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
@ -66,15 +77,41 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return batch
|
||||
|
||||
def progress_string(self):
|
||||
return ('\n' + '%11s' *
|
||||
(4 + len(self.loss_names))) % ('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
return ('\n' + '%11s' * (4 + len(self.loss_names))) % \
|
||||
('Epoch', 'GPU_mem', *self.loss_names, 'Instances', 'Size')
|
||||
|
||||
def get_validator(self):
|
||||
self.loss_names = ['loss']
|
||||
return v8.classify.ClassificationValidator(self.test_loader, self.save_dir, logger=self.console)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"])
|
||||
return loss, loss
|
||||
loss = torch.nn.functional.cross_entropy(preds, batch["cls"], reduction='sum') / self.args.nbs
|
||||
loss_items = loss.detach()
|
||||
return loss, loss_items
|
||||
|
||||
# def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
# """
|
||||
# Returns a loss dict with labelled training loss items tensor
|
||||
# """
|
||||
# # Not needed for classification but necessary for segmentation & detection
|
||||
# keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
# if loss_items is not None:
|
||||
# loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||
# return dict(zip(keys, loss_items))
|
||||
# else:
|
||||
# return keys
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
if loss_items is not None:
|
||||
loss_items = [round(float(loss_items), 5)]
|
||||
return dict(zip(keys, loss_items))
|
||||
else:
|
||||
return keys
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
pass
|
||||
@ -86,12 +123,16 @@ class ClassificationTrainer(BaseTrainer):
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "yolov8n-cls.yaml" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
|
||||
# trainer = ClassificationTrainer(cfg)
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
cfg.data = cfg.data or "mnist160" # or yolo.ClassificationDataset("mnist")
|
||||
cfg.lr0 = 0.1
|
||||
cfg.weight_decay = 5e-5
|
||||
cfg.label_smoothing = 0.1
|
||||
cfg.warmup_epochs = 0.0
|
||||
trainer = ClassificationTrainer(cfg)
|
||||
trainer.train()
|
||||
# from ultralytics import YOLO
|
||||
# model = YOLO(cfg.model)
|
||||
# model.train(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,5 +1,4 @@
|
||||
import hydra
|
||||
import torch
|
||||
|
||||
from ultralytics.yolo.data import build_classification_dataloader
|
||||
from ultralytics.yolo.engine.validator import BaseValidator
|
||||
@ -13,8 +12,12 @@ class ClassificationValidator(BaseValidator):
|
||||
super().__init__(dataloader, save_dir, pbar, logger, args)
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self):
|
||||
return ('%22s' + '%11s' * 2) % ('classes', 'top1_acc', 'top5_acc')
|
||||
|
||||
def init_metrics(self, model):
|
||||
self.correct = torch.tensor([], device=next(model.parameters()).device)
|
||||
self.pred = []
|
||||
self.targets = []
|
||||
|
||||
def preprocess(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True)
|
||||
@ -23,17 +26,20 @@ class ClassificationValidator(BaseValidator):
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
targets = batch["cls"]
|
||||
correct_in_batch = (targets[:, None] == preds).float()
|
||||
self.correct = torch.cat((self.correct, correct_in_batch))
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :5])
|
||||
self.targets.append(batch["cls"])
|
||||
|
||||
def get_stats(self):
|
||||
self.metrics.process(self.correct)
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
return self.metrics.results_dict
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
return build_classification_dataloader(path=dataset_path, imgsz=self.args.imgsz, batch_size=batch_size)
|
||||
|
||||
def print_results(self):
|
||||
pf = '%22s' + '%11.3g' * len(self.metrics.keys) # print format
|
||||
self.logger.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=str(DEFAULT_CONFIG.parent), config_name=DEFAULT_CONFIG.name)
|
||||
def val(cfg):
|
||||
|
Reference in New Issue
Block a user