Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia
2023-01-08 00:34:34 +05:30
committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions

View File

@ -103,13 +103,9 @@ class YOLO:
Args:
verbose (bool): Controls verbosity.
"""
if not self.model:
LOGGER.info("model not initialized!")
self.model.info(verbose=verbose)
def fuse(self):
if not self.model:
LOGGER.info("model not initialized!")
self.model.fuse()
@smart_inference_mode()
@ -139,9 +135,6 @@ class YOLO:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
if not self.model:
raise ModuleNotFoundError("model not initialized!")
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides["mode"] = "val"
@ -177,8 +170,6 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
"""
if not self.model:
raise AttributeError("model not initialized. Use .new() or .load()")
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
@ -193,10 +184,8 @@ class YOLO:
self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None,
cfg=self.model.yaml if self.task != "classify" else None)
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
def to(self, device):