|
|
|
@ -1,8 +1,8 @@
|
|
|
|
|
import torch
|
|
|
|
|
import yaml
|
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
|
|
|
|
from ultralytics import yolo
|
|
|
|
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
|
|
|
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
|
|
|
|
from ultralytics.yolo.utils import LOGGER
|
|
|
|
|
from ultralytics.yolo.utils.checks import check_yaml
|
|
|
|
@ -146,7 +146,7 @@ class YOLO:
|
|
|
|
|
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
|
|
|
|
|
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
|
|
|
|
|
"""
|
|
|
|
|
if not self.model and not self.ckpt:
|
|
|
|
|
if not self.model:
|
|
|
|
|
raise Exception("model not initialized. Use .new() or .load()")
|
|
|
|
|
|
|
|
|
|
overrides = kwargs
|
|
|
|
@ -159,8 +159,10 @@ class YOLO:
|
|
|
|
|
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
|
|
|
|
|
|
|
|
|
|
self.trainer = self.TrainerClass(overrides=overrides)
|
|
|
|
|
# load pre-trained weights if found, else use the loaded model
|
|
|
|
|
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
|
|
|
|
|
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
|
|
|
|
model_cfg=self.model.yaml if self.task != "classify" else None)
|
|
|
|
|
self.model = self.trainer.model # override here to save memory
|
|
|
|
|
|
|
|
|
|
self.trainer.train()
|
|
|
|
|
|
|
|
|
|
def resume(self, task=None, model=None):
|
|
|
|
@ -199,6 +201,9 @@ class YOLO:
|
|
|
|
|
|
|
|
|
|
return task
|
|
|
|
|
|
|
|
|
|
def to(self, device):
|
|
|
|
|
self.model.to(device)
|
|
|
|
|
|
|
|
|
|
def _guess_ops_from_task(self, task):
|
|
|
|
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
|
|
|
|
|
# warning: eval is unsafe. Use with caution
|
|
|
|
|