Add YOLOv5 dataset yamls (#207)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -38,7 +38,7 @@ class ClassificationPredictor(BasePredictor):
|
||||
log_string += '%gx%g ' % im.shape[2:] # print string
|
||||
self.annotator = self.get_annotator(im0)
|
||||
|
||||
prob = preds[idx]
|
||||
prob = preds[idx].softmax(0)
|
||||
self.all_outputs.append(prob)
|
||||
# Print results
|
||||
top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices
|
||||
|
@ -25,6 +25,8 @@ class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
model = ClassificationModel(cfg, nc=self.data["nc"])
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
pretrained = False
|
||||
for m in model.modules():
|
||||
@ -35,9 +37,6 @@ class ClassificationTrainer(BaseTrainer):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
# Update defaults
|
||||
if self.args.imgsz == 640:
|
||||
self.args.imgsz = 224
|
||||
@ -68,12 +67,15 @@ class ClassificationTrainer(BaseTrainer):
|
||||
return # dont return ckpt. Classification doesn't support resume
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
return build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size if mode == "train" else (batch_size * 2),
|
||||
augment=mode == "train",
|
||||
rank=rank,
|
||||
workers=self.args.workers)
|
||||
loader = build_classification_dataloader(path=dataset_path,
|
||||
imgsz=self.args.imgsz,
|
||||
batch_size=batch_size if mode == "train" else (batch_size * 2),
|
||||
augment=mode == "train",
|
||||
rank=rank,
|
||||
workers=self.args.workers)
|
||||
if mode != "train":
|
||||
self.model.transforms = loader.dataset.torch_transforms # attach inference transforms
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device)
|
||||
@ -141,19 +143,18 @@ def train(cfg):
|
||||
cfg.weight_decay = 5e-5
|
||||
cfg.label_smoothing = 0.1
|
||||
cfg.warmup_epochs = 0.0
|
||||
trainer = ClassificationTrainer(cfg)
|
||||
trainer.train()
|
||||
# from ultralytics import YOLO
|
||||
# model = YOLO(cfg.model)
|
||||
# model.train(**cfg)
|
||||
# trainer = ClassificationTrainer(cfg)
|
||||
# trainer.train()
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.train(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
CLI usage:
|
||||
python ultralytics/yolo/v8/classify/train.py model=resnet18 data=imagenette160 epochs=1 imgsz=224
|
||||
|
||||
TODO:
|
||||
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
||||
yolo task=classify mode=train model=yolov8n-cls.pt data=mnist160 epochs=10 imgsz=32
|
||||
yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 imgsz=32
|
||||
yolo task=classify mode=predict model=runs/classify/train/weights/last.pt imgsz=32 source=ultralytics/assets/bus.jpg
|
||||
yolo mode=export model=runs/classify/train/weights/last.pt imgsz=32 format=torchscript
|
||||
"""
|
||||
train()
|
||||
|
Reference in New Issue
Block a user