Change class depending on dataset in model interface (#77)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 24a7c068ad
commit 48c95ba083
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -49,6 +49,16 @@ def test_model_resume():
print("Successfully caught resume assert!")
def test_model_train_pretrained():
model = YOLO()
model.load("balloon-detect.pt")
model.train(data="coco128.yaml", epochs=1, img_size=32)
model.new("yolov5n.yaml")
model.train(data="coco128.yaml", epochs=1, img_size=32)
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
model(img)
def test():
test_model_forward()
test_model_info()
@ -56,6 +66,7 @@ def test():
test_visualize_preds()
test_val()
test_model_resume()
test_model_train_pretrained()
if __name__ == "__main__":

@ -1,8 +1,8 @@
import torch
import yaml
from omegaconf import OmegaConf
from ultralytics import yolo
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.checks import check_yaml
@ -146,7 +146,7 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration. List of all args can be found in 'config' section.
You can pass all arguments as a yaml file in `cfg`. Other args are ignored if `cfg` file is passed
"""
if not self.model and not self.ckpt:
if not self.model:
raise Exception("model not initialized. Use .new() or .load()")
overrides = kwargs
@ -159,8 +159,10 @@ class YOLO:
raise Exception("dataset not provided! Please check if you have defined `data` in you configs")
self.trainer = self.TrainerClass(overrides=overrides)
# load pre-trained weights if found, else use the loaded model
self.trainer.model = self.trainer.load_model(weights=self.ckpt) if self.ckpt else self.model
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
model_cfg=self.model.yaml if self.task != "classify" else None)
self.model = self.trainer.model # override here to save memory
self.trainer.train()
def resume(self, task=None, model=None):
@ -199,6 +201,9 @@ class YOLO:
return task
def to(self, device):
self.model.to(device)
def _guess_ops_from_task(self, task):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task]
# warning: eval is unsafe. Use with caution

Loading…
Cancel
Save