YOLOv8 architecture updates from R&D branch (#88)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ultralytics import yolo
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
# from ultralytics import yolo
|
||||
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
@ -147,7 +147,7 @@ class YOLO:
|
||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||
"""
|
||||
if not self.model:
|
||||
raise Exception("model not initialized. Use .new() or .load()")
|
||||
raise AttributeError("model not initialized. Use .new() or .load()")
|
||||
|
||||
overrides = kwargs
|
||||
if kwargs.get("cfg"):
|
||||
@ -156,7 +156,7 @@ class YOLO:
|
||||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
||||
raise AttributeError("dataset not provided! Please check if you have defined `data` in you configs")
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||
@ -175,14 +175,15 @@ class YOLO:
|
||||
"""
|
||||
if task:
|
||||
if task.lower() not in MODEL_MAP:
|
||||
raise Exception(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
|
||||
raise SyntaxError(f"unrecognised task - {task}. Supported tasks are {MODEL_MAP.keys()}")
|
||||
else:
|
||||
ckpt = torch.load(model, map_location="cpu")
|
||||
task = ckpt["train_args"]["task"]
|
||||
del ckpt
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
|
||||
task=task.lower())
|
||||
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model if model else True})
|
||||
self.trainer = self.TrainerClass(overrides={"task": task.lower(), "resume": model or True})
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
@staticmethod
|
||||
@ -196,8 +197,7 @@ class YOLO:
|
||||
task = "segment"
|
||||
|
||||
if not task:
|
||||
raise Exception(
|
||||
"task or model not recognized! Please refer the docs at : ") # TODO: add gitHub and docs links
|
||||
raise SyntaxError("task or model not recognized! Please refer the docs at : ") # TODO: add docs links
|
||||
|
||||
return task
|
||||
|
||||
|
Reference in New Issue
Block a user