ultralytics 8.0.19 seg/det dataset warning and DDP-cls/seg fixes (#595)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2023-01-24 23:22:02 +01:00
committed by GitHub
parent 936414c615
commit 520825c4b2
21 changed files with 103 additions and 63 deletions

View File

@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer
class ClassificationTrainer(BaseTrainer):
def __init__(self, config=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "classify"
super().__init__(config, overrides)
super().__init__(cfg, overrides)
def set_model_attributes(self):
self.model.names = self.data["names"]

View File

@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator):
def val(cfg=DEFAULT_CFG):
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
cfg.data = cfg.data or "imagenette160"
cfg.data = cfg.data or "mnist160"
validator = ClassificationValidator(args=cfg)
validator(model=cfg.model)

View File

@ -18,11 +18,11 @@ from ultralytics.yolo.v8.detect.train import Loss
# BaseTrainer python usage
class SegmentationTrainer(v8.detect.DetectionTrainer):
def __init__(self, config=DEFAULT_CFG, overrides=None):
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
if overrides is None:
overrides = {}
overrides["task"] = "segment"
super().__init__(config, overrides)
super().__init__(cfg, overrides)
def get_model(self, cfg=None, weights=None, verbose=True):
model = SegmentationModel(cfg, ch=3, nc=self.data["nc"], verbose=verbose)