Docstring additions (#122)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher
2022-12-31 13:42:45 +01:00
committed by GitHub
parent c9f3e469cb
commit df4fc14c10
10 changed files with 291 additions and 73 deletions

View File

@ -4,11 +4,11 @@ import logging.config
import os
import platform
import sys
import tempfile
import threading
from pathlib import Path
import cv2
import IPython
import pandas as pd
# Constants
@ -25,22 +25,25 @@ TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}' # tqdm bar format
LOGGING_NAME = 'yolov5'
HELP_MSG = \
"""
Please refer to below Usage examples for help running YOLOv8:
Usage examples for running YOLOv8:
1. Install the ultralytics package:
Install:
pip install ultralytics
Python SDK:
2. Use the Python SDK:
from ultralytics import YOLO
model = YOLO.new('yolov8n.yaml') # create a new model from scratch
model = YOLO.load('yolov8n.pt') # load a pretrained model (recommended for best training results)
results = model.train(data='coco128.yaml')
results = model.val()
results = model.predict(source='bus.jpg')
success = model.export(format='onnx')
model = YOLO.new('yolov8n.yaml') # create a new model from scratch
model = YOLO.load('yolov8n.pt') # load a pretrained model (recommended for best training results)
results = model.train(data='coco128.yaml') # train the model
results = model.val() # evaluate model performance on the validation set
results = model.predict(source='bus.jpg') # predict on an image
success = model.export(format='onnx') # export the model to ONNX format
3. Use the command line interface (CLI):
CLI:
yolo task=detect mode=train model=yolov8n.yaml args...
classify predict yolov8n-cls.yaml args...
segment val yolov8n-seg.yaml args...
@ -60,41 +63,67 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
def is_colab():
# Is environment a Google Colab instance?
"""
Check if the current script is running inside a Google Colab notebook.
Returns:
bool: True if running inside a Colab notebook, False otherwise.
"""
# Check if the google.colab module is present in sys.modules
return 'google.colab' in sys.modules
def is_kaggle():
# Is environment a Kaggle Notebook?
"""
Check if the current script is running inside a Kaggle kernel.
Returns:
bool: True if running inside a Kaggle kernel, False otherwise.
"""
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
def is_notebook():
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
ipython_type = str(type(IPython.get_ipython()))
return 'colab' in ipython_type or 'zmqshell' in ipython_type
def is_jupyter_notebook():
"""
Check if the current script is running inside a Jupyter Notebook.
Verified on Colab, Jupyterlab, Kaggle, Paperspace.
def is_docker() -> bool:
"""Check if the process runs inside a docker container."""
if Path("/.dockerenv").exists():
return True
try: # check if docker is in control groups
with open("/proc/self/cgroup") as file:
return any("docker" in line for line in file)
except OSError:
Returns:
bool: True if running inside a Jupyter Notebook, False otherwise.
"""
# Check if the get_ipython function exists
# (it does not exist when running as a standalone script)
try:
from IPython import get_ipython
return get_ipython() is not None
except ImportError:
return False
def is_writeable(dir, test=False):
# Return True if directory has write permissions, test opening a file with write permissions if test=True
if not test:
return os.access(dir, os.W_OK) # possible issues on Windows
file = Path(dir) / 'tmp.txt'
def is_docker() -> bool:
"""
Determine if the script is running inside a Docker container.
Returns:
bool: True if the script is running inside a Docker container, False otherwise.
"""
with open('/proc/self/cgroup') as f:
return 'docker' in f.read()
def is_dir_writeable(dir_path: str) -> bool:
"""
Check if a directory is writeable.
Args:
dir_path (str): The path to the directory.
Returns:
bool: True if the directory is writeable, False otherwise.
"""
try:
with open(file, 'w'): # open file with write permissions
with tempfile.TemporaryFile(dir=dir_path):
pass
file.unlink() # remove file
return True
except OSError:
return False
@ -106,20 +135,40 @@ def get_default_args(func):
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
env = os.getenv(env_var)
if env:
path = Path(env) # use environment variable
def get_user_config_dir(sub_dir='Ultralytics'):
"""
Get the user config directory.
Args:
sub_dir (str): The name of the subdirectory to create.
Returns:
Path: The path to the user config directory.
"""
# Get the operating system name
os_name = platform.system()
# Return the appropriate config directory for each operating system
if os_name == 'Windows':
path = Path.home() / 'AppData' / 'Roaming' / sub_dir
elif os_name == 'Darwin': # macOS
path = Path.home() / 'Library' / 'Application Support' / sub_dir
elif os_name == 'Linux':
path = Path.home() / '.config' / sub_dir
else:
cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
path.mkdir(exist_ok=True) # make if required
raise ValueError(f'Unsupported operating system: {os_name}')
# GCP and AWS lambda fix, only /tmp is writeable
if not is_dir_writeable(path.parent):
path = Path('/tmp') / sub_dir
# Create the subdirectory if it does not exist
path.mkdir(parents=True, exist_ok=True)
return path
USER_CONFIG_DIR = user_config_dir() # Ultralytics settings dir
USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir
def emojis(str=''):

View File

@ -12,7 +12,7 @@ import pkg_resources as pkg
import torch
from ultralytics.yolo.utils import (AUTOINSTALL, FONT, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, emojis,
is_docker, is_notebook)
is_docker, is_jupyter_notebook)
from ultralytics.yolo.utils.ops import make_divisible
@ -160,7 +160,7 @@ def check_yaml(file, suffix=('.yaml', '.yml')):
def check_imshow(warn=False):
# Check if environment supports image displays
try:
assert not is_notebook()
assert not is_jupyter_notebook()
assert not is_docker()
cv2.imshow('test', np.zeros((1, 1, 3)))
cv2.waitKey(1)

View File

@ -24,8 +24,21 @@ class WorkingDirectory(contextlib.ContextDecorator):
def increment_path(path, exist_ok=False, sep='', mkdir=False):
"""
Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
# TODO: docs
Increments a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
If the path exists and exist_ok is not set to True, the path will be incremented by appending a number and sep to
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
number will be appended directly to the end of the path. If mkdir is set to True, the path will be created as a
directory if it does not already exist.
Args:
path (str or pathlib.Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and will be returned as-is. Defaults to False.
sep (str, optional): Separator to use between the path and the incrementation number. Defaults to an empty string.
mkdir (bool, optional): If True, the path will be created as a directory if it does not exist. Defaults to False.
Returns:
pathlib.Path: Incremented path.
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:

View File

@ -100,10 +100,31 @@ def non_max_suppression(
max_det=300,
nm=0, # number of masks
):
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
"""
Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.
Arguments:
prediction (torch.Tensor): A tensor of shape (batch_size, num_boxes, num_classes + 4 + num_masks)
containing the predicted boxes, classes, and masks. The tensor should be in the format
output by a model, such as YOLO.
conf_thres (float): The confidence threshold below which boxes will be filtered out.
Valid values are between 0.0 and 1.0.
iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
Valid values are between 0.0 and 1.0.
classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
agnostic (bool): If True, the model is agnostic to the number of classes, and all
classes will be considered as one.
multi_label (bool): If True, each box may have multiple labels.
labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
list contains the apriori labels for a given image. The list should be in the format
output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
max_det (int): The maximum number of boxes to keep after NMS.
nm (int): The number of masks output by the model.
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
List[torch.Tensor]: A list of length batch_size, where each element is a tensor of
shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
(x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
"""
# Checks