Model interface enhancement (#106)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
single_channel
Ayush Chaurasia 2 years ago committed by GitHub
parent 38d6df55cb
commit 384f0ef1c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -91,14 +91,14 @@ jobs:
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64
yolo task=detect mode=val model=runs/train/exp/weights/last.pt imgsz=64 yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64
- name: Test segmentation # TODO: segmentation CI - name: Test segmentation # TODO: segmentation CI
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
# yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 # yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64
# yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64 # yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64
- name: Test classification # TODO: change to exp3 on Segmentation CI update - name: Test classification # TODO: change to exp3 on Segmentation CI update
shell: bash # for Windows compatibility shell: bash # for Windows compatibility
run: | run: |
yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32
yolo task=classify mode=val model=runs/train/exp2/weights/last.pt data=mnist160 yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160

@ -1,11 +1,11 @@
import torch import torch
from ultralytics.yolo import YOLO from ultralytics import YOLO
def test_model_forward(): def test_model_forward():
model = YOLO() model = YOLO()
model.new("yolov8n-seg.yaml") model.new("yolov8n.yaml")
img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512)
model.forward(img) model.forward(img)
model(img) model(img)
@ -15,7 +15,7 @@ def test_model_info():
model = YOLO() model = YOLO()
model.new("yolov8n.yaml") model.new("yolov8n.yaml")
model.info() model.info()
model.load("balloon-detect.pt") model.load("best.pt")
model.info(verbose=True) model.info(verbose=True)
@ -23,35 +23,35 @@ def test_model_fuse():
model = YOLO() model = YOLO()
model.new("yolov8n.yaml") model.new("yolov8n.yaml")
model.fuse() model.fuse()
model.load("balloon-detect.pt") model.load("best.pt")
model.fuse() model.fuse()
def test_visualize_preds(): def test_visualize_preds():
model = YOLO() model = YOLO()
model.load("balloon-segment.pt") model.load("best.pt")
model.predict(source="ultralytics/assets") model.predict(source="ultralytics/assets")
def test_val(): def test_val():
model = YOLO() model = YOLO()
model.load("balloon-segment.pt") model.load("best.pt")
model.val(data="coco128-seg.yaml", imgsz=32) model.val(data="coco128.yaml", imgsz=32)
def test_model_resume(): def test_model_resume():
model = YOLO() model = YOLO()
model.new("yolov8n-seg.yaml") model.new("yolov8n.yaml")
model.train(epochs=1, imgsz=32, data="coco128-seg.yaml") model.train(epochs=1, imgsz=32, data="coco128.yaml")
try: try:
model.resume(task="segment") model.resume(task="detect")
except AssertionError: except AssertionError:
print("Successfully caught resume assert!") print("Successfully caught resume assert!")
def test_model_train_pretrained(): def test_model_train_pretrained():
model = YOLO() model = YOLO()
model.load("balloon-detect.pt") model.load("best.pt")
model.train(data="coco128.yaml", epochs=1, imgsz=32) model.train(data="coco128.yaml", epochs=1, imgsz=32)
model.new("yolov8n.yaml") model.new("yolov8n.yaml")
model.train(data="coco128.yaml", epochs=1, imgsz=32) model.train(data="coco128.yaml", epochs=1, imgsz=32)

@ -43,6 +43,7 @@ class YOLO:
self.trainer = None self.trainer = None
self.task = None self.task = None
self.ckpt = None self.ckpt = None
self.overrides = {}
def new(self, cfg: str): def new(self, cfg: str):
""" """
@ -69,6 +70,10 @@ class YOLO:
""" """
self.ckpt = torch.load(weights, map_location="cpu") self.ckpt = torch.load(weights, map_location="cpu")
self.task = self.ckpt["train_args"]["task"] self.task = self.ckpt["train_args"]["task"]
self.overrides = dict(self.ckpt["train_args"])
self.overrides["device"] = '' # reset device
LOGGER.info("Device has been reset to ''")
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task(
task=self.task) task=self.task)
self.model = attempt_load_weights(weights) self.model = attempt_load_weights(weights)
@ -107,6 +112,7 @@ class YOLO:
source (str): Accepts all source types accepted by yolo source (str): Accepts all source types accepted by yolo
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs **kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
""" """
kwargs.update(self.overrides)
predictor = self.PredictorClass(overrides=kwargs) predictor = self.PredictorClass(overrides=kwargs)
# check size type # check size type
@ -119,7 +125,7 @@ class YOLO:
predictor.setup(model=self.model, source=source) predictor.setup(model=self.model, source=source)
predictor() predictor()
def val(self, data, **kwargs): def val(self, data=None, **kwargs):
""" """
Validate a model on a given dataset Validate a model on a given dataset
@ -130,8 +136,9 @@ class YOLO:
if not self.model: if not self.model:
raise Exception("model not initialized!") raise Exception("model not initialized!")
kwargs.update(self.overrides)
args = get_config(config=DEFAULT_CONFIG, overrides=kwargs) args = get_config(config=DEFAULT_CONFIG, overrides=kwargs)
args.data = data args.data = data or args.data
args.task = self.task args.task = self.task
validator = self.ValidatorClass(args=args) validator = self.ValidatorClass(args=args)

@ -86,10 +86,15 @@ class BasePredictor:
# data # data
if self.data: if self.data:
if self.data.endswith(".yaml"): try:
self.data = check_dataset_yaml(self.data) if self.data.endswith(".yaml"):
else: self.data = check_dataset_yaml(self.data)
self.data = check_dataset(self.data) else:
self.data = check_dataset(self.data)
except AssertionError as e:
LOGGER.info(f"Error ocurred: {e}")
finally:
LOGGER.info("Predictor will continue without reading the dataset")
# model # model
device = select_device(self.args.device) device = select_device(self.args.device)

@ -46,10 +46,15 @@ class BaseTrainer:
self.validator = None self.validator = None
self.model = None self.model = None
self.callbacks = defaultdict(list) self.callbacks = defaultdict(list)
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok)
# dirs
project = overrides.get("project") or self.args.task
name = overrides.get("name") or self.args.mode
self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok)
self.wdir = self.save_dir / 'weights' # weights dir self.wdir = self.save_dir / 'weights' # weights dir
self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths
self.batch_size = self.args.batch_size self.batch_size = self.args.batch_size
self.epochs = self.args.epochs self.epochs = self.args.epochs
self.start_epoch = 0 self.start_epoch = 0

@ -6,7 +6,7 @@ from omegaconf import DictConfig, OmegaConf
from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch
def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict]): def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}):
""" """
Accepts yaml file name or DictConfig containing experiment configuration. Accepts yaml file name or DictConfig containing experiment configuration.
Returns training args namespace Returns training args namespace

Loading…
Cancel
Save