YOLOv8 architecture updates from R&D branch (#88)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -18,10 +18,7 @@ from ..detect import DetectionTrainer
|
||||
class SegmentationTrainer(DetectionTrainer):
|
||||
|
||||
def load_model(self, model_cfg=None, weights=None):
|
||||
model = SegmentationModel(model_cfg or weights["model"].yaml,
|
||||
ch=3,
|
||||
nc=self.data["nc"],
|
||||
anchors=self.args.get("anchors"))
|
||||
model = SegmentationModel(model_cfg or weights["model"].yaml, ch=3, nc=self.data["nc"])
|
||||
if weights:
|
||||
model.load(weights)
|
||||
for _, v in model.named_parameters():
|
||||
@ -29,7 +26,7 @@ class SegmentationTrainer(DetectionTrainer):
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
self.loss_names = 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss'
|
||||
self.loss_names = 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss'
|
||||
return v8.segment.SegmentationValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
logger=self.console,
|
||||
@ -235,7 +232,7 @@ class SegmentationTrainer(DetectionTrainer):
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
cfg.model = cfg.model or "models/yolov5n-seg.yaml"
|
||||
cfg.model = cfg.model or "models/yolov8n-seg.yaml"
|
||||
cfg.data = cfg.data or "coco128-seg.yaml" # or yolo.ClassificationDataset("mnist")
|
||||
trainer = SegmentationTrainer(cfg)
|
||||
trainer.train()
|
||||
@ -244,7 +241,7 @@ def train(cfg):
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
CLI usage:
|
||||
python ultralytics/yolo/v8/segment/train.py model=yolov5n-seg.yaml data=coco128-segments epochs=100 imgsz=640
|
||||
python ultralytics/yolo/v8/segment/train.py model=yolov8n-seg.yaml data=coco128-segments epochs=100 imgsz=640
|
||||
|
||||
TODO:
|
||||
Direct cli support, i.e, yolov8 classify_train args.epochs 10
|
||||
|
Reference in New Issue
Block a user