ultralytics 8.0.19
seg/det dataset warning and DDP-cls/seg fixes (#595)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: 曾逸夫(Zeng Yifu) <41098760+Zengyf-CVer@users.noreply.github.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
This commit is contained in:
@ -13,11 +13,11 @@ from ultralytics.yolo.utils.torch_utils import strip_optimizer
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
|
||||
def __init__(self, config=DEFAULT_CFG, overrides=None):
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None):
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "classify"
|
||||
super().__init__(config, overrides)
|
||||
super().__init__(cfg, overrides)
|
||||
|
||||
def set_model_attributes(self):
|
||||
self.model.names = self.data["names"]
|
||||
|
@ -47,7 +47,7 @@ class ClassificationValidator(BaseValidator):
|
||||
|
||||
def val(cfg=DEFAULT_CFG):
|
||||
cfg.model = cfg.model or "yolov8n-cls.pt" # or "resnet18"
|
||||
cfg.data = cfg.data or "imagenette160"
|
||||
cfg.data = cfg.data or "mnist160"
|
||||
validator = ClassificationValidator(args=cfg)
|
||||
validator(model=cfg.model)
|
||||
|
||||
|
Reference in New Issue
Block a user