Change class depending on dataset in model interface (#77)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ultralytics import yolo
|
||||
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
@ -146,7 +146,7 @@ class YOLO:
|
||||
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
||||
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
||||
"""
|
||||
if not self.model and not self.ckpt:
|
||||
if not self.model:
|
||||
raise Exception("model not initialized. Use .new() or .load()")
|
||||
|
||||
overrides = kwargs
|
||||
@ -159,8 +159,10 @@ class YOLO:
|
||||
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
# load pre-trained weights if found, else use the loaded model
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||
model_cfg=self.model.yaml if self.task != "classify" else None)
|
||||
self.model = self.trainer.model # override here to save memory
|
||||
|
||||
self.trainer.train()
|
||||
|
||||
def resume(self, task=None, model=None):
|
||||
@ -199,6 +201,9 @@ class YOLO:
|
||||
|
||||
return task
|
||||
|
||||
def to(self, device):
|
||||
self.model.to(device)
|
||||
|
||||
def _guess_ops_from_task(self, task):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
||||
# warning: eval is unsafe. Use with caution
|
||||
|
Reference in New Issue
Block a user