Update `check_version()` for inequality support (#4182)

single_channel
Glenn Jocher 1 year ago committed by GitHub
parent 038558cfab
commit eb80ba9ce0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -39,6 +39,7 @@ setup(
install_requires=REQUIREMENTS, install_requires=REQUIREMENTS,
extras_require={ extras_require={
'dev': [ 'dev': [
'ipython',
'check-manifest', 'check-manifest',
'pytest', 'pytest',
'pytest-cov', 'pytest-cov',

@ -252,6 +252,8 @@ def handle_yolo_settings(args: List[str]) -> None:
python my_script.py yolo settings reset python my_script.py yolo settings reset
``` ```
""" """
url = 'https://docs.ultralytics.com/quickstart/#ultralytics-settings' # help URL
try:
if any(args): if any(args):
if args[0] == 'reset': if args[0] == 'reset':
SETTINGS_YAML.unlink() # delete the settings file SETTINGS_YAML.unlink() # delete the settings file
@ -262,7 +264,10 @@ def handle_yolo_settings(args: List[str]) -> None:
check_dict_alignment(SETTINGS, new) check_dict_alignment(SETTINGS, new)
SETTINGS.update(new) SETTINGS.update(new)
LOGGER.info(f'💡 Learn about settings at {url}')
yaml_print(SETTINGS_YAML) # print the current settings yaml_print(SETTINGS_YAML) # print the current settings
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ settings error: '{e}'. Please see {url} for help.")
def parse_key_value_pair(pair): def parse_key_value_pair(pair):

@ -18,7 +18,7 @@ except (ImportError, AssertionError):
def on_pretrain_routine_end(trainer): def on_pretrain_routine_end(trainer):
"""Logs training parameters to MLflow.""" """Logs training parameters to MLflow."""
global mlflow, run, run_id, experiment_name global mlflow, run, experiment_name
if os.environ.get('MLFLOW_TRACKING_URI') is None: if os.environ.get('MLFLOW_TRACKING_URI') is None:
mlflow = None mlflow = None
@ -39,8 +39,7 @@ def on_pretrain_routine_end(trainer):
run, active_run = mlflow, mlflow.active_run() run, active_run = mlflow, mlflow.active_run()
if not active_run: if not active_run:
active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name) active_run = mlflow.start_run(experiment_id=experiment.experiment_id, run_name=run_name)
run_id = active_run.info.run_id LOGGER.info(f'{prefix}Using run_id({active_run.info.run_id}) at {mlflow_location}')
LOGGER.info(f'{prefix}Using run_id({run_id}) at {mlflow_location}')
run.log_params(vars(trainer.model.args)) run.log_params(vars(trainer.model.args))
except Exception as err: except Exception as err:
LOGGER.error(f'{prefix}Failing init - {repr(err)}') LOGGER.error(f'{prefix}Failing init - {repr(err)}')

@ -90,54 +90,60 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
def check_version(current: str = '0.0.0', def check_version(current: str = '0.0.0',
minimum: str = '0.0.0', required: str = '0.0.0',
maximum: str = None,
name: str = 'version ', name: str = 'version ',
pinned: bool = False,
hard: bool = False, hard: bool = False,
verbose: bool = False) -> bool: verbose: bool = False) -> bool:
""" """
Check current version against the required minimum and/or maximum version. Check current version against the required version or range.
Args: Args:
current (str): Current version. current (str): Current version.
minimum (str): Required minimum version. required (str): Required version or range (in pip-style format).
maximum (str, optional): Required maximum version.
name (str): Name to be used in warning message. name (str): Name to be used in warning message.
pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied. hard (bool): If True, raise an AssertionError if the requirement is not met.
hard (bool): If True, raise an AssertionError if the minimum or maximum version is not met. verbose (bool): If True, print warning message if requirement is not met.
verbose (bool): If True, print warning message if minimum or maximum version is not met.
Returns: Returns:
(bool): True if minimum and maximum versions are met, False otherwise. (bool): True if requirement is met, False otherwise.
Example: Example:
```python # check if current version is exactly 22.04
# Check if current version is exactly 22.04 check_version(current='22.04', required='==22.04')
check_version(current='22.04', minimum='22.04', pinned=True)
# Check if current version is greater than or equal to 22.04 # check if current version is greater than or equal to 22.04
check_version(current='22.10', minimum='22.04') check_version(current='22.10', required='22.04') # assumes '>=' inequality if none passed
# Check if current version is less than or equal to 22.04 # check if current version is less than or equal to 22.04
check_version(current='22.04', maximum='22.04') check_version(current='22.04', required='<=22.04')
# Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) # check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
check_version(current='21.10', minimum='20.04', maximum='22.04') check_version(current='21.10', required='>20.04,<22.04')
```
""" """
current = pkg.parse_version(current) current = pkg.parse_version(current)
minimum = pkg.parse_version(minimum) constraints = re.findall(r'([<>!=]{1,2}\s*\d+\.\d+)', required) or [f'>={required}']
maximum = pkg.parse_version(maximum) if maximum else None
if pinned: result = True
result = (current == minimum) for constraint in constraints:
else: op, version = re.match(r'([<>!=]{1,2})\s*(\d+\.\d+)', constraint).groups()
result = (current >= minimum) and (current <= maximum if maximum else True) version = pkg.parse_version(version)
version_message = f'a version between {minimum} and {maximum}' if maximum else f'a minimum version {minimum}' if op == '==' and current != version:
warning_message = f'WARNING ⚠️ {name} requires {version_message}, but {name}{current} is currently installed.' result = False
elif op == '!=' and current == version:
result = False
elif op == '>=' and not (current >= version):
result = False
elif op == '<=' and not (current <= version):
result = False
elif op == '>' and not (current > version):
result = False
elif op == '<' and not (current < version):
result = False
if not result:
warning_message = f'WARNING ⚠️ {name}{required} is required, but {name}{current} is currently installed'
if hard: if hard:
assert result, emojis(warning_message) # assert version requirements met raise ModuleNotFoundError(emojis(warning_message)) # assert version requirements met
if verbose and not result: if verbose:
LOGGER.warning(warning_message) LOGGER.warning(warning_message)
return result return result

Loading…
Cancel
Save