Improvements (#142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
		| @ -73,7 +73,7 @@ from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages | ||||
| from ultralytics.yolo.data.utils import check_dataset | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, get_default_args, yaml_save | ||||
| from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml | ||||
| from ultralytics.yolo.utils.files import file_size, increment_path | ||||
| from ultralytics.yolo.utils.files import file_size | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| from ultralytics.yolo.utils.torch_utils import guess_task_from_head, select_device, smart_inference_mode | ||||
|  | ||||
| @ -138,10 +138,6 @@ class Exporter: | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         self.args = get_config(config, overrides) | ||||
|         project = self.args.project or f"runs/{self.args.task}" | ||||
|         name = self.args.name or "exp"  # hardcode mode as export doesn't require it | ||||
|         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) | ||||
|         self.save_dir.mkdir(parents=True, exist_ok=True) | ||||
|         self.callbacks = defaultdict(list, {k: [v] for k, v in callbacks.default_callbacks.items()})  # add callbacks | ||||
|         callbacks.add_integration_callbacks(self) | ||||
|  | ||||
|  | ||||
| @ -35,7 +35,7 @@ from ultralytics.nn.autobackend import AutoBackend | ||||
| from ultralytics.yolo.configs import get_config | ||||
| from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages, LoadScreenshots, LoadStreams | ||||
| from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, callbacks, colorstr, ops | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, SETTINGS, callbacks, colorstr, ops | ||||
| from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_imshow | ||||
| from ultralytics.yolo.utils.files import increment_path | ||||
| from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode | ||||
| @ -73,7 +73,7 @@ class BasePredictor: | ||||
|         if overrides is None: | ||||
|             overrides = {} | ||||
|         self.args = get_config(config, overrides) | ||||
|         project = self.args.project or f"runs/{self.args.task}" | ||||
|         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task | ||||
|         name = self.args.name or f"{self.args.mode}" | ||||
|         self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) | ||||
|         (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True) | ||||
|  | ||||
| @ -25,7 +25,8 @@ import ultralytics.yolo.utils as utils | ||||
| from ultralytics import __version__ | ||||
| from ultralytics.yolo.configs import get_config | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks, colorstr, yaml_save | ||||
| from ultralytics.yolo.utils import (DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, | ||||
|                                     yaml_save) | ||||
| from ultralytics.yolo.utils.checks import check_file, print_args | ||||
| from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command | ||||
| from ultralytics.yolo.utils.files import get_latest_run, increment_path | ||||
| @ -88,7 +89,7 @@ class BaseTrainer: | ||||
|         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) | ||||
|  | ||||
|         # Dirs | ||||
|         project = self.args.project or f"runs/{self.args.task}" | ||||
|         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task | ||||
|         name = self.args.name or f"{self.args.mode}" | ||||
|         self.save_dir = Path( | ||||
|             self.args.get( | ||||
|  | ||||
| @ -8,7 +8,7 @@ from tqdm import tqdm | ||||
|  | ||||
| from ultralytics.nn.autobackend import AutoBackend | ||||
| from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, TQDM_BAR_FORMAT, callbacks | ||||
| from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks | ||||
| from ultralytics.yolo.utils.checks import check_imgsz | ||||
| from ultralytics.yolo.utils.files import increment_path | ||||
| from ultralytics.yolo.utils.ops import Profile | ||||
| @ -59,7 +59,7 @@ class BaseValidator: | ||||
|         self.speed = None | ||||
|         self.jdict = None | ||||
|  | ||||
|         project = self.args.project or f"runs/{self.args.task}" | ||||
|         project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task | ||||
|         name = self.args.name or f"{self.args.mode}" | ||||
|         self.save_dir = save_dir or increment_path(Path(project) / name, | ||||
|                                                    exist_ok=self.args.exist_ok if RANK in {-1, 0} else True) | ||||
|  | ||||
| @ -18,7 +18,6 @@ FILE = Path(__file__).resolve() | ||||
| ROOT = FILE.parents[2]  # YOLO | ||||
| DEFAULT_CONFIG = ROOT / "yolo/configs/default.yaml" | ||||
| RANK = int(os.getenv('RANK', -1)) | ||||
| DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory | ||||
| NUM_THREADS = min(8, max(1, os.cpu_count() - 1))  # number of YOLOv5 multiprocessing threads | ||||
| AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode | ||||
| FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf | ||||
| @ -119,6 +118,41 @@ def is_docker() -> bool: | ||||
|         return 'docker' in f.read() | ||||
|  | ||||
|  | ||||
| def is_git_directory() -> bool: | ||||
|     """ | ||||
|     Check if the current working directory is inside a git repository. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if the current working directory is inside a git repository, False otherwise. | ||||
|     """ | ||||
|     from git import Repo | ||||
|     try: | ||||
|         # Check if the current working directory is a git repository | ||||
|         Repo(search_parent_directories=True) | ||||
|         return True | ||||
|     except Exception: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def is_pip_package(filepath: str = __name__) -> bool: | ||||
|     """ | ||||
|     Determines if the file at the given filepath is part of a pip package. | ||||
|  | ||||
|     Args: | ||||
|         filepath (str): The filepath to check. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if the file is part of a pip package, False otherwise. | ||||
|     """ | ||||
|     import importlib.util | ||||
|  | ||||
|     # Get the spec for the module | ||||
|     spec = importlib.util.find_spec(filepath) | ||||
|  | ||||
|     # Return whether the spec is not None and the origin is not None (indicating it is a package) | ||||
|     return spec is not None and spec.origin is not None | ||||
|  | ||||
|  | ||||
| def is_dir_writeable(dir_path: str) -> bool: | ||||
|     """ | ||||
|     Check if a directory is writeable. | ||||
| @ -305,10 +339,11 @@ def get_settings(file=USER_CONFIG_DIR / 'settings.yaml'): | ||||
|     """ | ||||
|     from ultralytics.yolo.utils.torch_utils import torch_distributed_zero_first | ||||
|  | ||||
|     git_install = not is_pip_package() | ||||
|     defaults = { | ||||
|         'datasets_dir': None,  # default datasets directory. If None, current working directory is used. | ||||
|         'weights_dir': None,  # default weights directory. If None, current working directory is used. | ||||
|         'runs_dir': None,  # default runs directory. If None, current working directory is used. | ||||
|         'datasets_dir': str(ROOT / 'datasets') if git_install else 'datasets',  # default datasets directory. | ||||
|         'weights_dir': str(ROOT / 'weights') if git_install else 'weights',  # default weights directory. | ||||
|         'runs_dir': str(ROOT / 'runs') if git_install else 'runs',  # default runs directory. | ||||
|         'sync': True,  # sync analytics to help with YOLO development | ||||
|         'uuid': uuid.getnode(),  # device UUID to align analytics | ||||
|         'yaml_file': str(file)}  # setting YAML file path | ||||
| @ -336,6 +371,7 @@ if platform.system() == 'Windows': | ||||
|  | ||||
| # Check first-install steps | ||||
| SETTINGS = get_settings() | ||||
| DATASETS_DIR = Path(SETTINGS['datasets_dir'])  # global datasets directory | ||||
|  | ||||
|  | ||||
| def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'): | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| import glob | ||||
| import inspect | ||||
| import math | ||||
| import platform | ||||
| import urllib | ||||
| from pathlib import Path | ||||
| @ -13,71 +14,141 @@ import torch | ||||
|  | ||||
| from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, | ||||
|                                     is_docker, is_jupyter_notebook) | ||||
| from ultralytics.yolo.utils.ops import make_divisible | ||||
|  | ||||
|  | ||||
| def is_ascii(s=''): | ||||
|     # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) | ||||
|     s = str(s)  # convert list, tuple, None, etc. to str | ||||
|     return len(s.encode().decode('ascii', 'ignore')) == len(s) | ||||
| def is_ascii(s) -> bool: | ||||
|     """ | ||||
|     Check if a string is composed of only ASCII characters. | ||||
|  | ||||
|     Args: | ||||
|         s (str): String to be checked. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if the string is composed only of ASCII characters, False otherwise. | ||||
|     """ | ||||
|     # Convert list, tuple, None, etc. to string | ||||
|     s = str(s) | ||||
|  | ||||
|     # Check if the string is composed of only ASCII characters | ||||
|     return all(ord(c) < 128 for c in s) | ||||
|  | ||||
|  | ||||
| def check_imgsz(imgsz, stride=32, min_dim=1, floor=0): | ||||
|     # Verify image size is a multiple of stride s in each dimension | ||||
|     """ | ||||
|     Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the | ||||
|     stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. | ||||
|  | ||||
|     Args: | ||||
|         imgsz (int or List[int]): Image size. | ||||
|         stride (int): Stride value. | ||||
|         min_dim (int): Minimum number of dimensions. | ||||
|         floor (int): Minimum allowed value for image size. | ||||
|  | ||||
|     Returns: | ||||
|         List[int]: Updated image size. | ||||
|     """ | ||||
|     # Convert stride to integer if it is a tensor | ||||
|     stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) | ||||
|     if isinstance(imgsz, int):  # integer i.e. imgsz=640 | ||||
|         sz = max(make_divisible(imgsz, stride), floor) | ||||
|     else:  # list i.e. imgsz=[640, 480] | ||||
|         imgsz = list(imgsz)  # convert to list if tuple | ||||
|         sz = [max(make_divisible(x, stride), floor) for x in imgsz] | ||||
|  | ||||
|     # Convert image size to list if it is an integer | ||||
|     if isinstance(imgsz, int): | ||||
|         imgsz = [imgsz] | ||||
|  | ||||
|     # Make image size a multiple of the stride | ||||
|     sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] | ||||
|  | ||||
|     # Print warning message if image size was updated | ||||
|     if sz != imgsz: | ||||
|         LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {stride}, updating to {sz}') | ||||
|  | ||||
|     # Check dims | ||||
|     if min_dim == 2: | ||||
|         if isinstance(imgsz, int): | ||||
|             sz = [sz, sz] | ||||
|         elif len(sz) == 1: | ||||
|     # Add missing dimensions if necessary | ||||
|     if min_dim == 2 and len(sz) == 1: | ||||
|         sz = [sz[0], sz[0]] | ||||
|  | ||||
|     return sz | ||||
|  | ||||
|  | ||||
| def check_version(current="0.0.0", minimum="0.0.0", name="version ", pinned=False, hard=False, verbose=False): | ||||
|     # Check version vs. required version | ||||
|     current, minimum = (pkg.parse_version(x) for x in (current, minimum)) | ||||
| def check_version(current: str = "0.0.0", | ||||
|                   minimum: str = "0.0.0", | ||||
|                   name: str = "version ", | ||||
|                   pinned: bool = False, | ||||
|                   hard: bool = False, | ||||
|                   verbose: bool = False) -> bool: | ||||
|     """ | ||||
|     Check current version against the required minimum version. | ||||
|  | ||||
|     Args: | ||||
|         current (str): Current version. | ||||
|         minimum (str): Required minimum version. | ||||
|         name (str): Name to be used in warning message. | ||||
|         pinned (bool): If True, versions must match exactly. If False, minimum version must be satisfied. | ||||
|         hard (bool): If True, raise an AssertionError if the minimum version is not met. | ||||
|         verbose (bool): If True, print warning message if minimum version is not met. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if minimum version is met, False otherwise. | ||||
|     """ | ||||
|     from pkg_resources import parse_version | ||||
|     current, minimum = (parse_version(x) for x in (current, minimum)) | ||||
|     result = (current == minimum) if pinned else (current >= minimum)  # bool | ||||
|     s = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed"  # string | ||||
|     warning_message = f"WARNING ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed" | ||||
|     if hard: | ||||
|         assert result, emojis(s)  # assert min requirements met | ||||
|         assert result, emojis(warning_message)  # assert min requirements met | ||||
|     if verbose and not result: | ||||
|         LOGGER.warning(s) | ||||
|         LOGGER.warning(warning_message) | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def check_font(font=FONT, progress=False): | ||||
|     # Download font to CONFIG_DIR if necessary | ||||
| def check_font(font: str = FONT, progress: bool = False) -> None: | ||||
|     """ | ||||
|     Download font file to the user's configuration directory if it does not already exist. | ||||
|  | ||||
|     Args: | ||||
|         font (str): Path to font file. | ||||
|         progress (bool): If True, display a progress bar during the download. | ||||
|  | ||||
|     Returns: | ||||
|         None | ||||
|     """ | ||||
|     font = Path(font) | ||||
|  | ||||
|     # Destination path for the font file | ||||
|     file = USER_CONFIG_DIR / font.name | ||||
|  | ||||
|     # Check if font file exists at the source or destination path | ||||
|     if not font.exists() and not file.exists(): | ||||
|         # Download font file | ||||
|         url = f'https://ultralytics.com/assets/{font.name}' | ||||
|         LOGGER.info(f'Downloading {url} to {file}...') | ||||
|         torch.hub.download_url_to_file(url, str(file), progress=progress) | ||||
|  | ||||
|  | ||||
| def check_online(): | ||||
|     # Check internet connectivity | ||||
| def check_online() -> bool: | ||||
|     """ | ||||
|     Check internet connectivity by attempting to connect to a known online host. | ||||
|  | ||||
|     Returns: | ||||
|         bool: True if connection is successful, False otherwise. | ||||
|     """ | ||||
|     import socket | ||||
|     try: | ||||
|         socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility | ||||
|         # Check host accessibility by attempting to establish a connection | ||||
|         socket.create_connection(("1.1.1.1", 443), timeout=5) | ||||
|         return True | ||||
|     except OSError: | ||||
|         return False | ||||
|  | ||||
|  | ||||
| def check_python(minimum='3.7.0'): | ||||
|     # Check current python version vs. required python version | ||||
| def check_python(minimum: str = '3.7.0') -> bool: | ||||
|     """ | ||||
|     Check current python version against the required minimum version. | ||||
|  | ||||
|     Args: | ||||
|         minimum (str): Required minimum version of python. | ||||
|  | ||||
|     Returns: | ||||
|         None | ||||
|     """ | ||||
|     check_version(platform.python_version(), minimum, name='Python ', hard=True) | ||||
|  | ||||
|  | ||||
|  | ||||
		Reference in New Issue
	
	Block a user