From df4fc14c109a8b0851bb54ada9588d42a805ec64 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 31 Dec 2022 13:42:45 +0100 Subject: [PATCH] Docstring additions (#122) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ultralytics/yolo/data/augment.py | 4 +- ultralytics/yolo/engine/exporter.py | 23 ++++- ultralytics/yolo/engine/model.py | 37 ++++---- ultralytics/yolo/engine/predictor.py | 30 +++++- ultralytics/yolo/engine/trainer.py | 61 +++++++++++- ultralytics/yolo/engine/validator.py | 28 +++++- ultralytics/yolo/utils/__init__.py | 133 ++++++++++++++++++--------- ultralytics/yolo/utils/checks.py | 4 +- ultralytics/yolo/utils/files.py | 17 +++- ultralytics/yolo/utils/ops.py | 25 ++++- 10 files changed, 290 insertions(+), 72 deletions(-) diff --git a/ultralytics/yolo/data/augment.py b/ultralytics/yolo/data/augment.py index b33236a..9636b1e 100644 --- a/ultralytics/yolo/data/augment.py +++ b/ultralytics/yolo/data/augment.py @@ -436,7 +436,9 @@ class LetterBox: self.scaleup = scaleup self.stride = stride - def __call__(self, labels={}, image=None): + def __call__(self, labels=None, image=None): + if labels is None: + labels = {} img = labels.get("img") if image is None else image shape = img.shape[:2] # current shape [height, width] new_shape = labels.pop("rect_shape", self.new_shape) diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index dba30bb..c20ba22 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -116,14 +116,31 @@ def try_export(inner_func): class Exporter: + """ + Exporter + + A class for exporting a model. + + Attributes: + args (OmegaConf): Configuration for the exporter. + save_dir (Path): Directory to save results. + """ + + def __init__(self, config=DEFAULT_CONFIG, overrides=None): + """ + Initializes the Exporter class. 
- def __init__(self, config=DEFAULT_CONFIG, overrides={}): + Args: + config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. + overrides (dict, optional): Configuration overrides. Defaults to None. + """ + if overrides is None: + overrides = {} self.args = get_config(config, overrides) project = self.args.project or f"runs/{self.args.task}" name = self.args.name or "exp" # hardcode mode as export doesn't require it self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok) self.save_dir.mkdir(parents=True, exist_ok=True) - self.imgsz = self.args.imgsz @smart_inference_mode() def __call__(self, model=None): @@ -143,7 +160,7 @@ class Exporter: assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic' # Checks - self.imgsz = check_imgsz(self.imgsz, stride=model.stride, min_dim=2) # check image size + self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size if self.args.optimize: assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e.
use --device cpu' diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index 80b94b8..ada350f 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -1,6 +1,6 @@ import torch -from ultralytics import yolo # noqa required for python usage +from ultralytics import yolo # noqa from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights from ultralytics.yolo.configs import get_config from ultralytics.yolo.engine.exporter import Exporter @@ -9,7 +9,7 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_yaml from ultralytics.yolo.utils.files import yaml_load from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode -# map head: [model, trainer, validator, predictor] +# Map head to model, trainer, validator, and predictor classes MODEL_MAP = { "classify": [ ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator', @@ -24,39 +24,44 @@ MODEL_MAP = { class YOLO: """ - Python interface which emulates a model-like behaviour by wrapping trainers. + YOLO + + A Python interface which emulates a model-like behaviour by wrapping trainers. """ - __init_key = object() + __init_key = object() # used to ensure proper initialization def __init__(self, init_key=None, type="v8") -> None: """ + Initializes the YOLO object. + Args: - type (str): Type/version of models to use + init_key (object): used to ensure proper initialization. Defaults to None. + type (str): Type/version of models to use. Defaults to "v8".
""" if init_key != YOLO.__init_key: raise SyntaxError(HELP_MSG) self.type = type - self.ModelClass = None - self.TrainerClass = None - self.ValidatorClass = None - self.PredictorClass = None - self.model = None - self.trainer = None - self.task = None + self.ModelClass = None # model class + self.TrainerClass = None # trainer class + self.ValidatorClass = None # validator class + self.PredictorClass = None # predictor class + self.model = None # model object + self.trainer = None # trainer object + self.task = None # task type self.ckpt = None # if loaded from *.pt self.cfg = None # if loaded from *.yaml - self.overrides = {} - self.init_disabled = False + self.overrides = {} # overrides for trainer object + self.init_disabled = False # disable model initialization @classmethod def new(cls, cfg: str, verbose=True): """ - Initializes a new model and infers the task type from the model definitions + Initializes a new model and infers the task type from the model definitions. Args: cfg (str): model configuration file - verbsoe (bool): display model info on load + verbose (bool): display model info on load """ cfg = check_yaml(cfg) # check YAML cfg_dict = yaml_load(cfg) # model dict diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index b26002c..3fe3c23 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -41,8 +41,36 @@ from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mo class BasePredictor: + """ + BasePredictor + + A base class for creating predictors. + + Attributes: + args (OmegaConf): Configuration for the predictor. + save_dir (Path): Directory to save results. + done_setup (bool): Whether the predictor has finished setup. + model (nn.Module): Model used for prediction. + data (dict): Data configuration. + device (torch.device): Device used for prediction. + dataset (Dataset): Dataset used for prediction. + vid_path (str): Path to video file. 
+ vid_writer (cv2.VideoWriter): Video writer for saving video output. + view_img (bool): Whether to view image output. + annotator (Annotator): Annotator used for prediction. + data_path (str): Path to data. + """ + + def __init__(self, config=DEFAULT_CONFIG, overrides=None): + """ + Initializes the BasePredictor class. - def __init__(self, config=DEFAULT_CONFIG, overrides={}): + Args: + config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. + overrides (dict, optional): Configuration overrides. Defaults to None. + """ + if overrides is None: + overrides = {} self.args = get_config(config, overrides) project = self.args.project or f"runs/{self.args.task}" name = self.args.name or f"{self.args.mode}" diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index 8c751e3..7f43554 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -33,9 +33,53 @@ from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds class BaseTrainer: + """ + BaseTrainer + + A base class for creating trainers. + + Attributes: + args (OmegaConf): Configuration for the trainer. + check_resume (method): Method to check if training should be resumed from a saved checkpoint. + console (logging.Logger): Logger instance. + validator (BaseValidator): Validator instance. + model (nn.Module): Model instance. + callbacks (defaultdict): Dictionary of callbacks. + save_dir (Path): Directory to save results. + wdir (Path): Directory to save weights. + last (Path): Path to last checkpoint. + best (Path): Path to best checkpoint. + batch_size (int): Batch size for training. + epochs (int): Number of epochs to train for. + start_epoch (int): Starting epoch for training. + device (torch.device): Device to use for training. + amp (bool): Flag to enable AMP (Automatic Mixed Precision). + scaler (amp.GradScaler): Gradient scaler for AMP. + data (str): Path to data. 
+ trainset (torch.utils.data.Dataset): Training dataset. + testset (torch.utils.data.Dataset): Testing dataset. + ema (nn.Module): EMA (Exponential Moving Average) of the model. + lf (nn.Module): Loss function. + scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. + best_fitness (float): The best fitness value achieved. + fitness (float): Current fitness value. + loss (float): Current loss value. + tloss (float): Total loss value. + loss_names (list): List of loss names. + csv (Path): Path to results CSV file. + """ + + def __init__(self, config=DEFAULT_CONFIG, overrides=None): + """ + Initializes the BaseTrainer class. - def __init__(self, cfg=DEFAULT_CONFIG, overrides={}): - self.args = get_config(cfg, overrides) + Args: + config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG. + overrides (dict, optional): Configuration overrides. Defaults to None. + """ + if overrides is None: + overrides = {} + self.args = get_config(config, overrides) self.check_resume() init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) @@ -464,6 +508,19 @@ class BaseTrainer: @staticmethod def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): + """ + Builds an optimizer with the specified parameters and parameter groups. + + Args: + model (nn.Module): model to optimize + name (str): name of the optimizer to use + lr (float): learning rate + momentum (float): momentum + decay (float): weight decay + + Returns: + torch.optim.Optimizer: the built optimizer + """ g = [], [], [] # optimizer parameter groups bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. 
BatchNorm2d() for v in model.modules(): diff --git a/ultralytics/yolo/engine/validator.py b/ultralytics/yolo/engine/validator.py index 75c722d..7f6d8b3 100644 --- a/ultralytics/yolo/engine/validator.py +++ b/ultralytics/yolo/engine/validator.py @@ -16,10 +16,36 @@ from ultralytics.yolo.utils.torch_utils import de_parallel, select_device, smart class BaseValidator: """ - Base validator class. + BaseValidator + + A base class for creating validators. + + Attributes: + dataloader (DataLoader): Dataloader to use for validation. + pbar (tqdm): Progress bar to update during validation. + logger (logging.Logger): Logger to use for validation. + args (OmegaConf): Configuration for the validator. + model (nn.Module): Model to validate. + data (dict): Data dictionary. + device (torch.device): Device to use for validation. + batch_i (int): Current batch index. + training (bool): Whether the model is in training mode. + speed (float): Batch processing speed in seconds. + jdict (dict): Dictionary to store validation results. + save_dir (Path): Directory to save results. """ def __init__(self, dataloader=None, save_dir=None, pbar=None, logger=None, args=None): + """ + Initializes a BaseValidator instance. + + Args: + dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation. + save_dir (Path): Directory to save results. + pbar (tqdm.tqdm): Progress bar for displaying progress. + logger (logging.Logger): Logger to log messages. + args (OmegaConf): Configuration for the validator. 
+ """ self.dataloader = dataloader self.pbar = pbar self.logger = logger or LOGGER diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py index 21b4e1f..81ff33c 100644 --- a/ultralytics/yolo/utils/__init__.py +++ b/ultralytics/yolo/utils/__init__.py @@ -4,11 +4,11 @@ import logging.config import os import platform import sys +import tempfile import threading from pathlib import Path import cv2 -import IPython import pandas as pd # Constants @@ -25,22 +25,25 @@ TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format LOGGING_NAME = 'yolov5' HELP_MSG = \ """ - Please refer to below Usage examples for help running YOLOv8: + Usage examples for running YOLOv8: + + 1. Install the ultralytics package: - Install: pip install ultralytics - Python SDK: + 2. Use the Python SDK: + from ultralytics import YOLO - model = YOLO.new('yolov8n.yaml') # create a new model from scratch - model = YOLO.load('yolov8n.pt') # load a pretrained model (recommended for best training results) - results = model.train(data='coco128.yaml') - results = model.val() - results = model.predict(source='bus.jpg') - success = model.export(format='onnx') + model = YOLO.new('yolov8n.yaml') # create a new model from scratch + model = YOLO.load('yolov8n.pt') # load a pretrained model (recommended for best training results) + results = model.train(data='coco128.yaml') # train the model + results = model.val() # evaluate model performance on the validation set + results = model.predict(source='bus.jpg') # predict on an image + success = model.export(format='onnx') # export the model to ONNX format + + 3. Use the command line interface (CLI): - CLI: yolo task=detect mode=train model=yolov8n.yaml args... classify predict yolov8n-cls.yaml args... segment val yolov8n-seg.yaml args... @@ -60,41 +63,67 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads def is_colab(): - # Is environment a Google Colab instance? 
+ """ + Check if the current script is running inside a Google Colab notebook. + + Returns: + bool: True if running inside a Colab notebook, False otherwise. + """ + # Check if the google.colab module is present in sys.modules return 'google.colab' in sys.modules def is_kaggle(): - # Is environment a Kaggle Notebook? + """ + Check if the current script is running inside a Kaggle kernel. + + Returns: + bool: True if running inside a Kaggle kernel, False otherwise. + """ return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' -def is_notebook(): - # Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace - ipython_type = str(type(IPython.get_ipython())) - return 'colab' in ipython_type or 'zmqshell' in ipython_type +def is_jupyter_notebook(): + """ + Check if the current script is running inside a Jupyter Notebook. + Verified on Colab, Jupyterlab, Kaggle, Paperspace. + + Returns: + bool: True if running inside a Jupyter Notebook, False otherwise. + """ + # Check whether an IPython shell is active: get_ipython() returns None outside IPython + # (the import itself fails, and we return False, when IPython is not installed) + try: + from IPython import get_ipython + return get_ipython() is not None + except ImportError: + return False def is_docker() -> bool: - """Check if the process runs inside a docker container.""" - if Path("/.dockerenv").exists(): - return True - try: # check if docker is in control groups - with open("/proc/self/cgroup") as file: - return any("docker" in line for line in file) - except OSError: - return False + """ + Determine if the script is running inside a Docker container. + + Returns: + bool: True if the script is running inside a Docker container, False otherwise.
+ """ + with open('/proc/self/cgroup') as f: + return 'docker' in f.read() -def is_writeable(dir, test=False): - # Return True if directory has write permissions, test opening a file with write permissions if test=True - if not test: - return os.access(dir, os.W_OK) # possible issues on Windows - file = Path(dir) / 'tmp.txt' +def is_dir_writeable(dir_path: str) -> bool: + """ + Check if a directory is writeable. + + Args: + dir_path (str): The path to the directory. + + Returns: + bool: True if the directory is writeable, False otherwise. + """ try: - with open(file, 'w'): # open file with write permissions + with tempfile.TemporaryFile(dir=dir_path): pass - file.unlink() # remove file return True except OSError: return False @@ -106,20 +135,40 @@ def get_default_args(func): return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} -def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'): - # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. - env = os.getenv(env_var) - if env: - path = Path(env) # use environment variable +def get_user_config_dir(sub_dir='Ultralytics'): + """ + Get the user config directory. + + Args: + sub_dir (str): The name of the subdirectory to create. + + Returns: + Path: The path to the user config directory. 
+ """ + # Get the operating system name + os_name = platform.system() + + # Return the appropriate config directory for each operating system + if os_name == 'Windows': + path = Path.home() / 'AppData' / 'Roaming' / sub_dir + elif os_name == 'Darwin': # macOS + path = Path.home() / 'Library' / 'Application Support' / sub_dir + elif os_name == 'Linux': + path = Path.home() / '.config' / sub_dir else: - cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs - path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir - path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable - path.mkdir(exist_ok=True) # make if required + raise ValueError(f'Unsupported operating system: {os_name}') + + # GCP and AWS lambda fix, only /tmp is writeable + if not is_dir_writeable(path.parent): + path = Path('/tmp') / sub_dir + + # Create the subdirectory if it does not exist + path.mkdir(parents=True, exist_ok=True) + return path -USER_CONFIG_DIR = user_config_dir() # Ultralytics settings dir +USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir def emojis(str=''): diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index 1135ae7..4699a3d 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -12,7 +12,7 @@ import pkg_resources as pkg import torch from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis, - is_docker, is_notebook) + is_docker, is_jupyter_notebook) from ultralytics.yolo.utils.ops import make_divisible @@ -160,7 +160,7 @@ def check_yaml(file, suffix=('.yaml', '.yml')): def check_imshow(warn=False): # Check if environment supports image displays try: - assert not is_notebook() + assert not is_jupyter_notebook() assert not is_docker() cv2.imshow('test', np.zeros((1, 1, 3))) cv2.waitKey(1) diff --git 
a/ultralytics/yolo/utils/files.py b/ultralytics/yolo/utils/files.py index 7d0e5e3..e185226 100644 --- a/ultralytics/yolo/utils/files.py +++ b/ultralytics/yolo/utils/files.py @@ -24,8 +24,21 @@ class WorkingDirectory(contextlib.ContextDecorator): def increment_path(path, exist_ok=False, sep='', mkdir=False): """ - Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. - # TODO: docs + Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. + + If the path exists and exist_ok is not set to True, the path will be incremented by appending sep and a number to + the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the + number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a + directory if it does not already exist. + + Args: + path (str or pathlib.Path): Path to increment. + exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False. + sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string. + mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False. + + Returns: + pathlib.Path: Incremented path. """ path = Path(path) # os-agnostic if path.exists() and not exist_ok: diff --git a/ultralytics/yolo/utils/ops.py b/ultralytics/yolo/utils/ops.py index 55efc10..41b0db0 100644 --- a/ultralytics/yolo/utils/ops.py +++ b/ultralytics/yolo/utils/ops.py @@ -100,10 +100,31 @@ def non_max_suppression( max_det=300, nm=0, # number of masks ): - """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections + """ + Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
+ + Arguments: + prediction (torch.Tensor): A tensor of shape (batch_size, num_boxes, num_classes + 4 + num_masks) + containing the predicted boxes, classes, and masks. The tensor should be in the format + output by a model, such as YOLO. + conf_thres (float): The confidence threshold below which boxes will be filtered out. + Valid values are between 0.0 and 1.0. + iou_thres (float): The IoU threshold above which overlapping boxes will be suppressed during NMS. + Valid values are between 0.0 and 1.0. + classes (List[int]): A list of class indices to consider. If None, all classes will be considered. + agnostic (bool): If True, the model is agnostic to the number of classes, and all + classes will be considered as one. + multi_label (bool): If True, each box may have multiple labels. + labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner + list contains the apriori labels for a given image. The list should be in the format + output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2). + max_det (int): The maximum number of boxes to keep after NMS. + nm (int): The number of masks output by the model. Returns: - list of detections, on (n,6) tensor per image [xyxy, conf, cls] + List[torch.Tensor]: A list of length batch_size, where each element is a tensor of + shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns + (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). """ # Checks