From 681cfc1c35f06c25f3842e108623db9b5df045b0 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Tue, 20 Dec 2022 09:50:08 +0530 Subject: [PATCH] Make config overrides user friendly (#80) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher --- ultralytics/yolo/utils/configs/__init__.py | 4 + ultralytics/yolo/utils/configs/hydra_patch.py | 79 +++++++++++++++++++ ultralytics/yolo/v8/__init__.py | 3 + 3 files changed, 86 insertions(+) create mode 100644 ultralytics/yolo/utils/configs/hydra_patch.py diff --git a/ultralytics/yolo/utils/configs/__init__.py b/ultralytics/yolo/utils/configs/__init__.py index d06463b..44cb09f 100644 --- a/ultralytics/yolo/utils/configs/__init__.py +++ b/ultralytics/yolo/utils/configs/__init__.py @@ -3,6 +3,8 @@ from typing import Dict, Union from omegaconf import DictConfig, OmegaConf +from ultralytics.yolo.utils.configs.hydra_patch import check_config_mismatch + def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): """ @@ -20,4 +22,6 @@ def get_config(config: Union[str, DictConfig], overrides: Union[str, Dict] = {}) elif isinstance(overrides, Dict): overrides = OmegaConf.create(overrides) + check_config_mismatch(dict(overrides).keys(), dict(config).keys()) + return OmegaConf.merge(config, overrides) diff --git a/ultralytics/yolo/utils/configs/hydra_patch.py b/ultralytics/yolo/utils/configs/hydra_patch.py new file mode 100644 index 0000000..d381697 --- /dev/null +++ b/ultralytics/yolo/utils/configs/hydra_patch.py @@ -0,0 +1,79 @@ +import sys +from difflib import get_close_matches +from textwrap import dedent + +import hydra +from hydra.errors import ConfigCompositionException +from omegaconf import OmegaConf, open_dict +from omegaconf.errors import ConfigAttributeError, ConfigKeyError, OmegaConfBaseException + +from ultralytics.yolo.utils import LOGGER, colorstr + + +def override_config(overrides, cfg): + override_keys = [override.key_or_group for override in overrides] + check_config_mismatch(override_keys, cfg.keys()) + for override in overrides: + if override.package is not None: + raise ConfigCompositionException(f"Override {override.input_line} looks like a config group" + f" override, but config group '{override.key_or_group}' does not" + " exist.") + + key = override.key_or_group + value = override.value() + try: + if override.is_delete(): + config_val = OmegaConf.select(cfg, key, throw_on_missing=False) + if config_val is None: + raise ConfigCompositionException(f"Could not delete from config. '{override.key_or_group}'" + " does not exist.") + elif value is not None and value != config_val: + raise ConfigCompositionException("Could not delete from config. The value of" + f" '{override.key_or_group}' is {config_val} and not" + f" {value}.") + + last_dot = key.rfind(".") + with open_dict(cfg): + if last_dot == -1: + del cfg[key] + else: + node = OmegaConf.select(cfg, key[0:last_dot]) + del node[key[last_dot + 1:]] + + elif override.is_add(): + if OmegaConf.select(cfg, key, throw_on_missing=False) is None or isinstance(value, (dict, list)): + OmegaConf.update(cfg, key, value, merge=True, force_add=True) + else: + assert override.input_line is not None + raise ConfigCompositionException( + dedent(f"""\ + Could not append to config. An item is already at '{override.key_or_group}'. + Either remove + prefix: '{override.input_line[1:]}' + Or add a second + to add or override '{override.key_or_group}': '+{override.input_line}' + """)) + elif override.is_force_add(): + OmegaConf.update(cfg, key, value, merge=True, force_add=True) + else: + try: + OmegaConf.update(cfg, key, value, merge=True) + except (ConfigAttributeError, ConfigKeyError) as ex: + raise ConfigCompositionException(f"Could not override '{override.key_or_group}'." + f"\nTo append to your config use +{override.input_line}") from ex + except OmegaConfBaseException as ex: + raise ConfigCompositionException(f"Error merging override {override.input_line}").with_traceback( + sys.exc_info()[2]) from ex + + +def check_config_mismatch(overrides, cfg): + mismatched = [] + for option in overrides: + if option not in cfg and 'hydra.' not in option: + mismatched.append(option) + + for option in mismatched: + LOGGER.info(f"{colorstr(option)} is not a valid key. Similar keys: {get_close_matches(option, cfg, 3, 0.6)}") + if mismatched: + exit() + + +hydra._internal.config_loader_impl.ConfigLoaderImpl._apply_overrides_to_config = override_config diff --git a/ultralytics/yolo/v8/__init__.py b/ultralytics/yolo/v8/__init__.py index cec773e..97834a1 100644 --- a/ultralytics/yolo/v8/__init__.py +++ b/ultralytics/yolo/v8/__init__.py @@ -5,3 +5,6 @@ from ultralytics.yolo.v8 import classify, detect, segment ROOT = Path(__file__).parents[0] # yolov8 ROOT __all__ = ["classify", "segment", "detect"] + +# Patch hydra cli +from ultralytics.yolo.utils.configs import hydra_patch