Add YOLOv5 dataset yamls (#207)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-10 18:33:44 +01:00
committed by GitHub
parent e371e81aa0
commit c7629e93bd
14 changed files with 2018 additions and 39 deletions

View File

@ -38,7 +38,7 @@ class ClassificationPredictor(BasePredictor):
log_string += '%gx%g ' % im.shape[2:] # print string
self.annotator = self.get_annotator(im0)
prob = preds[idx]
prob = preds[idx].softmax(0)
self.all_outputs.append(prob)
# Print results
top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices

View File

@ -25,6 +25,8 @@ class ClassificationTrainer(BaseTrainer):
def get_model(self, cfg=None, weights=None, verbose=True):
model = ClassificationModel(cfg, nc=self.data["nc"])
if weights:
model.load(weights)
pretrained = False
for m in model.modules():
@ -35,9 +37,6 @@ class ClassificationTrainer(BaseTrainer):
for p in model.parameters():
p.requires_grad = True # for training
if weights:
model.load(weights)
# Update defaults
if self.args.imgsz == 640:
self.args.imgsz = 224
@ -68,12 +67,15 @@ class ClassificationTrainer(BaseTrainer):
return # dont return ckpt. Classification doesn't support resume
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
return build_classification_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
batch_size=batch_size if mode == "train" else (batch_size * 2),
augment=mode == "train",
rank=rank,
workers=self.args.workers)
loader = build_classification_dataloader(path=dataset_path,
imgsz=self.args.imgsz,
batch_size=batch_size if mode == "train" else (batch_size * 2),
augment=mode == "train",
rank=rank,
workers=self.args.workers)
if mode != "train":
self.model.transforms = loader.dataset.torch_transforms # attach inference transforms
return loader
def preprocess_batch(self, batch):
batch["img"] = batch["img"].to(self.device)
@ -141,19 +143,18 @@ def train(cfg):
cfg.weight_decay = 5e-5
cfg.label_smoothing = 0.1
cfg.warmup_epochs = 0.0
trainer = ClassificationTrainer(cfg)
trainer.train()
# from ultralytics import YOLO
# model = YOLO(cfg.model)
# model.train(**cfg)
# trainer = ClassificationTrainer(cfg)
# trainer.train()
from ultralytics import YOLO
model = YOLO(cfg.model)
model.train(**cfg)
if __name__ == "__main__":
"""
CLI usage:
python ultralytics/yolo/v8/classify/train.py model=resnet18 data=imagenette160 epochs=1 imgsz=224
TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10
yolo task=classify mode=train model=yolov8n-cls.pt data=mnist160 epochs=10 imgsz=32
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
"""
train()