From 55bdca6768940046db7f3dc265a056dad9976a7b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 4 Jan 2023 16:37:46 +0100 Subject: [PATCH] Improvements (#142) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/engine/exporter.py | 6 +- ultralytics/yolo/engine/predictor.py | 4 +- ultralytics/yolo/engine/trainer.py | 5 +- ultralytics/yolo/engine/validator.py | 4 +- ultralytics/yolo/utils/__init__.py | 44 ++++++++- ultralytics/yolo/utils/checks.py | 133 ++++++++++++++++++++------- 6 files changed, 150 insertions(+), 46 deletions(-) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index 4dd4bd3..49a9742 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -73,7 +73,7 @@ from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages from ultralytics.yolo.data.utils import check_dataset from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml -from ultralytics.yolo.utils.files import file_size, increment_path +from ultralytics.yolo.utils.files import file_size from ultralytics.yolo.utils.ops import Profile from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode @@ -138,10 +138,6 @@ class Exporter: if overrides is None: overrides = {} self.args = get_config(config, overrides) - project = self.args.project or f"runs/{self.args.task}" - name = self.args.name or "exp" # hardcode mode as export doesn't require it - self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) - self.save_dir.mkdir(parents=True, exist_ok=True) self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()}) # add callbacks callbacks.add_integration_callbacks(self) diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index 42130e9..cd8537c 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.configs import get_config from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS -from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, ops +from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode @@ -73,7 +73,7 @@ class BasePredictor: if overrides is None: overrides = {} self.args = get_config(config, overrides) - project = self.args.project or f"runs/{self.args.task}" + project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task name = self.args.name or f"{self.args.mode}" self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 0d02c1a..4cbb4f0 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -25,7 +25,8 @@ import ultralytics.yolo.utils as utils from ultralytics import __version__ from ultralytics.yolo.configs import get_config from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks, colorstr, yaml_save +from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, + yaml_save) from ultralytics.yolo.utils.checks import check_file, print_args from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.yolo.utils.files import get_latest_run, increment_path @@ -88,7 +89,7 @@ class BaseTrainer: init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) # Dirs - project = self.args.project or f"runs/{self.args.task}" + project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task name = self.args.name or f"{self.args.mode}" self.save_dir = Path( self.args.get( diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 6715cb4..ed6fa60 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -8,7 +8,7 @@ from tqdm import tqdm from ultralytics.nn.autobackend import AutoBackend from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml -from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks +from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks from ultralytics.yolo.utils.checks import check_imgsz from ultralytics.yolo.utils.files import increment_path from ultralytics.yolo.utils.ops import Profile @@ -59,7 +59,7 @@ class BaseValidator: self.speed = None self.jdict = None - project = self.args.project or f"runs/{self.args.task}" + project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task name = self.args.name or f"{self.args.mode}" self.save_dir = save_dir or increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index be0b749..593311f 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -18,7 +18,6 @@ FILE = Path(__file__).resolve() ROOT = FILE.parents[2] # YOLO DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml" RANK = int(os.getenv('RANK', -1)) -DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf @@ -119,6 +118,41 @@ def is_docker() -> bool: return 'docker' in f.read() +def is_git_directory() -> bool: + """ + Check if the current working directory is inside a git repository. + + Returns: + bool: True if the current working directory is inside a git repository, False otherwise. + """ + from git import Repo + try: + # Check if the current working directory is a git repository + Repo(search_parent_directories=True) + return True + except Exception: + return False + + +def is_pip_package(filepath: str = __name__) -> bool: + """ + Determines if the file at the given filepath is part of a pip package. + + Args: + filepath (str): The filepath to check. + + Returns: + bool: True if the file is part of a pip package, False otherwise. + """ + import importlib.util + + # Get the spec for the module + spec = importlib.util.find_spec(filepath) + + # Return whether the spec is not None and the origin is not None (indicating it is a package) + return spec is not None and spec.origin is not None + + def is_dir_writeable(dir_path: str) -> bool: """ Check if a directory is writeable. @@ -305,10 +339,11 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): """ from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first + git_install = not is_pip_package() defaults = { - 'datasets_dir': None, # default datasets directory. If None, current working directory is used. - 'weights_dir': None, # default weights directory. If None, current working directory is used. - 'runs_dir': None, # default runs directory. If None, current working directory is used. + 'datasets_dir': str(ROOT / 'datasets') if git_install else 'datasets', # default datasets directory. + 'weights_dir': str(ROOT / 'weights') if git_install else 'weights', # default weights directory. + 'runs_dir': str(ROOT / 'runs') if git_install else 'runs', # default runs directory. 'sync': True, # sync analytics to help with YOLO development 'uuid': uuid.getnode(), # device UUID to align analytics 'yaml_file': str(file)} # setting YAML file path @@ -336,6 +371,7 @@ if platform.system() == 'Windows': # Check first-install steps SETTINGS = get_settings() +DATASETS_DIR = Path(SETTINGS['datasets_dir']) # global datasets directory def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 4699a3d..e50becb 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -1,5 +1,6 @@ import glob import inspect +import math import platform import urllib from pathlib import Path @@ -13,71 +14,141 @@ import torch from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, is_docker, is_jupyter_notebook) -from ultralytics.yolo.utils.ops import make_divisible -def is_ascii(s=''): - # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) - s = str(s) # convert list, tuple, None, etc. to str - return len(s.encode().decode('ascii', 'ignore')) == len(s) +def is_ascii(s) -> bool: + """ + Check if a string is composed of only ASCII characters. + Args: + s (str): String to be checked. -def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): - # Verify image size is a multiple of stride s in each dimension + Returns: + bool: True if the string is composed only of ASCII characters, False otherwise. + """ + # Convert list, tuple, None, etc. to string + s = str(s) + + # Check if the string is composed of only ASCII characters + return all(ord(c) < 128 for c in s) + +def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): + """ + Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the + stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. + + Args: + imgsz (int or List[int]): Image size. + stride (int): Stride value. + min_dim (int): Minimum number of dimensions. + floor (int): Minimum allowed value for image size. + + Returns: + List[int]: Updated image size. + """ + # Convert stride to integer if it is a tensor stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) - if isinstance(imgsz, int): # integer i.e. imgsz=640 - sz = max(make_divisible(imgsz, stride), floor) - else: # list i.e. imgsz=[640, 480] - imgsz = list(imgsz) # convert to list if tuple - sz = [max(make_divisible(x, stride), floor) for x in imgsz] + + # Convert image size to list if it is an integer + if isinstance(imgsz, int): + imgsz = [imgsz] + + # Make image size a multiple of the stride + sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] + + # Print warning message if image size was updated if sz != imgsz: LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}') - # Check dims - if min_dim == 2: - if isinstance(imgsz, int): - sz = [sz, sz] - elif len(sz) == 1: - sz = [sz[0], sz[0]] + # Add missing dimensions if necessary + if min_dim == 2 and len(sz) == 1: + sz = [sz[0], sz[0]] return sz -def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): - # Check version vs. required version - current, minimum = (pkg.parse_version(x) for x in (current, minimum)) +def check_version(current: str = "0.0.0", + minimum: str = "0.0.0", + name: str = "version ", + pinned: bool = False, + hard: bool = False, + verbose: bool = False) -> bool: + """ + Check current version against the required minimum version. + + Args: + current (str): Current version. + minimum (str): Required minimum version. + name (str): Name to be used in warning message. + pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied. + hard (bool): If True, raise an AssertionError if the minimum version is not met. + verbose (bool): If True, print warning message if minimum version is not met. + + Returns: + bool: True if minimum version is met, False otherwise. + """ + from pkg_resources import parse_version + current, minimum = (parse_version(x) for x in (current, minimum)) result = (current == minimum) if pinned else (current >= minimum) # bool - s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" # string + warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" if hard: - assert result, emojis(s) # assert min requirements met + assert result, emojis(warning_message) # assert min requirements met if verbose and not result: - LOGGER.warning(s) + LOGGER.warning(warning_message) return result -def check_font(font=FONT, progress=False): - # Download font to CONFIG_DIR if necessary +def check_font(font: str = FONT, progress: bool = False) -> None: + """ + Download font file to the user's configuration directory if it does not already exist. + + Args: + font (str): Path to font file. + progress (bool): If True, display a progress bar during the download. + + Returns: + None + """ font = Path(font) + + # Destination path for the font file file = USER_CONFIG_DIR / font.name + + # Check if font file exists at the source or destination path if not font.exists() and not file.exists(): + # Download font file url = f'https://ultralytics.com/assets/{font.name}' LOGGER.info(f'Downloading {url} to {file}...') torch.hub.download_url_to_file(url, str(file), progress=progress) -def check_online(): - # Check internet connectivity +def check_online() -> bool: + """ + Check internet connectivity by attempting to connect to a known online host. + + Returns: + bool: True if connection is successful, False otherwise. + """ import socket try: - socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility + # Check host accessibility by attempting to establish a connection + socket.create_connection(("1.1.1.1", 443), timeout=5) return True except OSError: return False -def check_python(minimum='3.7.0'): - # Check current python version vs. required python version +def check_python(minimum: str = '3.7.0') -> bool: + """ + Check current python version against the required minimum version. + + Args: + minimum (str): Required minimum version of python. + + Returns: + None + """ check_version(platform.python_version(), minimum, name='Python ', hard=True)