From 384f0ef1c67829c444de4e2ac19542b84d31acdc Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Wed, 28 Dec 2022 18:05:01 +0530 Subject: [PATCH] Model interface enhancement (#106) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- .github/workflows/ci.yaml | 6 +++--- tests/test_model.py | 22 +++++++++++----------- ultralytics/yolo/engine/model.py | 11 +++++++++-- ultralytics/yolo/engine/predictor.py | 13 +++++++++---- ultralytics/yolo/engine/trainer.py | 7 ++++++- ultralytics/yolo/utils/configs/__init__.py | 2 +- 6 files changed, 39 insertions(+), 22 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7b70533..77ec9dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -91,14 +91,14 @@ jobs: shell: bash # for Windows compatibility run: | yolo task=detect mode=train model=yolov8n.yaml data=coco128.yaml epochs=1 imgsz=64 - yolo task=detect mode=val model=runs/train/exp/weights/last.pt imgsz=64 + yolo task=detect mode=val model=runs/detect/train/weights/last.pt imgsz=64 - name: Test segmentation # TODO: segmentation CI shell: bash # for Windows compatibility run: | # yolo task=segment mode=train model=yolov8n-seg.yaml data=coco128-seg.yaml epochs=1 imgsz=64 - # yolo task=segment mode=val model=runs/train/exp2/weights/last.pt data=coco128-seg.yaml imgsz=64 + # yolo task=segment mode=val model=runs/segment/train/weights/last.pt data=coco128-seg.yaml imgsz=64 - name: Test classification # TODO: change to exp3 on Segmentation CI update shell: bash # for Windows compatibility run: | yolo task=classify mode=train model=resnet18 data=mnist160 epochs=1 imgsz=32 - yolo task=classify mode=val model=runs/train/exp2/weights/last.pt data=mnist160 + yolo task=classify mode=val model=runs/classify/train/weights/last.pt data=mnist160 diff --git a/tests/test_model.py b/tests/test_model.py index 306f11f..3146d81 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,11 +1,11 @@ import torch -from ultralytics.yolo import YOLO +from ultralytics import YOLO def test_model_forward(): model = YOLO() - model.new("yolov8n-seg.yaml") + model.new("yolov8n.yaml") img = torch.rand(512 * 512 * 3).view(1, 3, 512, 512) model.forward(img) model(img) @@ -15,7 +15,7 @@ def test_model_info(): model = YOLO() model.new("yolov8n.yaml") model.info() - model.load("balloon-detect.pt") + model.load("best.pt") model.info(verbose=True) @@ -23,35 +23,35 @@ def test_model_fuse(): model = YOLO() model.new("yolov8n.yaml") model.fuse() - model.load("balloon-detect.pt") + model.load("best.pt") model.fuse() def test_visualize_preds(): model = YOLO() - model.load("balloon-segment.pt") + model.load("best.pt") model.predict(source="ultralytics/assets") def test_val(): model = YOLO() - model.load("balloon-segment.pt") - model.val(data="coco128-seg.yaml", imgsz=32) + model.load("best.pt") + model.val(data="coco128.yaml", imgsz=32) def test_model_resume(): model = YOLO() - model.new("yolov8n-seg.yaml") - model.train(epochs=1, imgsz=32, data="coco128-seg.yaml") + model.new("yolov8n.yaml") + model.train(epochs=1, imgsz=32, data="coco128.yaml") try: - model.resume(task="segment") + model.resume(task="detect") except AssertionError: print("Successfully caught resume assert!") def test_model_train_pretrained(): model = YOLO() - model.load("balloon-detect.pt") + model.load("best.pt") model.train(data="coco128.yaml", epochs=1, imgsz=32) model.new("yolov8n.yaml") model.train(data="coco128.yaml", epochs=1, imgsz=32) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 1ecee36..e7720e7 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -43,6 +43,7 @@ class YOLO: self.trainer = None self.task = None self.ckpt = None + self.overrides = {} def new(self, cfg: str): """ @@ -69,6 +70,10 @@ class YOLO: """ self.ckpt = torch.load(weights, map_location="cpu") self.task = self.ckpt["train_args"]["task"] + self.overrides = dict(self.ckpt["train_args"]) + self.overrides["device"] = '' # reset device + LOGGER.info("Device has been reset to ''") + self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._guess_ops_from_task( task=self.task) self.model = attempt_load_weights(weights) @@ -107,6 +112,7 @@ class YOLO: source (str): Accepts all source types accepted by yolo **kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs """ + kwargs.update(self.overrides) predictor = self.PredictorClass(overrides=kwargs) # check size type @@ -119,7 +125,7 @@ class YOLO: predictor.setup(model=self.model, source=source) predictor() - def val(self, data, **kwargs): + def val(self, data=None, **kwargs): """ Validate a model on a given dataset @@ -130,8 +136,9 @@ class YOLO: if not self.model: raise Exception("model not initialized!") + kwargs.update(self.overrides) args = get_config(config=DEFAULT_CONFIG, overrides=kwargs) - args.data = data + args.data = data or args.data args.task = self.task validator = self.ValidatorClass(args=args) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 21c55d6..94b2b72 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -86,10 +86,15 @@ class BasePredictor: # data if self.data: - if self.data.endswith(".yaml"): - self.data = check_dataset_yaml(self.data) - else: - self.data = check_dataset(self.data) + try: + if self.data.endswith(".yaml"): + self.data = check_dataset_yaml(self.data) + else: + self.data = check_dataset(self.data) + except AssertionError as e: + LOGGER.info(f"Error ocurred: {e}") + finally: + LOGGER.info("Predictor will continue without reading the dataset") # model device = select_device(self.args.device) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 6757f3f..d0524ea 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -46,10 +46,15 @@ class BaseTrainer: self.validator = None self.model = None self.callbacks = defaultdict(list) - self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) + + # dirs + project = overrides.get("project") or self.args.task + name = overrides.get("name") or self.args.mode + self.save_dir = increment_path(Path("runs") / project / name, exist_ok=self.args.exist_ok) self.wdir = self.save_dir / 'weights' # weights dir self.wdir.mkdir(parents=True, exist_ok=True) # make dir self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths + self.batch_size = self.args.batch_size self.epochs = self.args.epochs self.start_epoch = 0 diff --git a/ultralytics/yolo/utils/configs/__init__.py b/ultralytics/yolo/utils/configs/__init__.py index e2ce966..a1e6cf9 100644 --- a/ultralytics/yolo/utils/configs/__init__.py +++ b/ultralytics/yolo/utils/configs/__init__.py @@ -6,7 +6,7 @@ from omegaconf import DictConfig, OmegaConf from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch -def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict]): +def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): """ Accepts yaml file name or DictConfig containing experiment configuration. Returns training args namespace